『ignite』: A Handy Toolkit for PyTorch

Although PyTorch already offers plenty of conveniences for implementing neural networks, human laziness knows no limits. This post introduces a further level of abstraction, the ignite toolkit, which simplifies the PyTorch training process even more.

1. Installation

pip install pytorch-ignite
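After installing, a quick sanity check is to print the package version (ignite exposes it as ignite.__version__):

import ignite
print(ignite.__version__)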

2. Basic Example

import torch
import torch.nn as nn

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

model = Net()  # Net and get_data_loaders are user-defined
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
criterion = nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, criterion)

val_metrics = {
    "accuracy": Accuracy(),
    "nll": Loss(criterion)
}
evaluator = create_supervised_evaluator(model, metrics=val_metrics)

log_interval = 10  # log the training loss every 10 iterations

@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(trainer):
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

trainer.run(train_loader, max_epochs=100)

As you can see, we first create the network model, the DataLoaders, the optimizer and the loss function, then use ignite's create_supervised_trainer and create_supervised_evaluator to replace the usual verbose hand-written loop. In addition, ignite provides an aspect-oriented event system that lets you run whatever you like at well-defined points, e.g. before or after each epoch or iteration; a sketch of the loop being replaced follows below.
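For reference, a rough sketch of the hand-written loop that create_supervised_trainer replaces might look like this (same model, optimizer and criterion as above; all the logging and evaluation code would still have to be written by hand):

# minimal hand-rolled equivalent of create_supervised_trainer's inner loop
for epoch in range(max_epochs):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()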

3. Engine

This is the core class of ignite. It is an abstraction that loops over the provided data a given number of times, executes a processing function, and returns the result:

while epoch < max_epochs:
    # run an epoch on data
    data_iter = iter(data)
    while True:
        try:
            batch = next(data_iter)
            output = process_function(batch)
            iter_counter += 1
        except StopIteration:
            data_iter = iter(data)
        if iter_counter == epoch_length:
            break

A model trainer is therefore simply an engine that loops over the training dataset multiple times and updates the model parameters. For example:

from ignite.engine import Engine

# model, optimizer, loss_fn and prepare_batch are assumed to be defined as above

def train_step(trainer, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

[Example 1] Create a basic trainer

def update_model(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print("Epoch {}/{} : {} - batch loss: {}, lr: {}".format(e, n, i, batch_loss, lr))

trainer.run(data_loader, max_epochs=5)

[Example 2] Create a basic evaluator and compute a metric

from ignite.metrics import Accuracy

def predict_on_batch(engine, batch):
    model.eval()
    with torch.no_grad():
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
    return y_pred, y

evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)

[Example 3] Compute the image mean/std over the training dataset

from ignite.metrics import Average

def compute_mean_std(engine, batch):
    b, c, *_ = batch['image'].shape
    data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
    mean = torch.mean(data, dim=-1).sum(dim=0)
    mean2 = torch.mean(data ** 2, dim=-1).sum(dim=0)
    return {"mean": mean, "mean^2": mean2}

compute_engine = Engine(compute_mean_std)

img_mean = Average(output_transform=lambda output: output['mean'])
img_mean.attach(compute_engine, 'mean')

img_mean2 = Average(output_transform=lambda output: output['mean^2'])
img_mean2.attach(compute_engine, 'mean2')

state = compute_engine.run(train_loader)
state.metrics['std'] = torch.sqrt(state.metrics['mean2'] - state.metrics['mean'] ** 2)

mean = state.metrics['mean'].tolist()
std = state.metrics['std'].tolist()

[Example 4] Resume the engine's run from a state. A user can load a state_dict and run the engine starting from the loaded state:

# Restore from an epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or an iteration
# state_dict = {"iteration": 500, "max_epochs": 100, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)

An Engine object also provides the following methods:

terminate(): sends a terminate signal to the engine, so that it stops the run completely after the current iteration.

terminate_epoch(): sends a terminate signal to the engine, so that it ends the current epoch after the current iteration.
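As a small sketch of terminate() in use, a handler could stop the run early once the batch loss drops below a threshold (loss_threshold is an illustrative value, not part of the library):

loss_threshold = 0.01  # hypothetical threshold, for illustration only

@trainer.on(Events.ITERATION_COMPLETED)
def stop_on_low_loss(engine):
    # engine.state.output holds the batch loss returned by the update function
    if engine.state.output < loss_threshold:
        engine.terminate()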

ignite.engine.create_supervised_trainer:

A factory function for creating a trainer for supervised models.

def create_supervised_trainer(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
    deterministic: bool = False,
) -> Engine:

model: the model to train

optimizer: the optimizer to use

loss_fn: the loss function to use

device: device type specification (default: None). The device can be CPU, GPU or TPU.

non_blocking: if True and the copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. In other cases this argument has no effect.

prepare_batch: a function that receives (batch, device, non_blocking) and outputs a tuple of tensors (batch_x, batch_y).

output_transform: a function that receives x, y, y_pred and loss and returns the value to be assigned to the engine's state.output after each iteration. Defaults to returning loss.item().

deterministic: if True, returns a DeterministicEngine, otherwise a plain Engine (default: False).
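As a sketch of how these parameters can be combined, suppose the dataset yields dict-shaped batches (the dict keys here are assumptions for illustration):

# hypothetical prepare_batch for batches stored as dicts, with GPU placement
def my_prepare_batch(batch, device=None, non_blocking=False):
    x = batch["image"].to(device, non_blocking=non_blocking)
    y = batch["label"].to(device, non_blocking=non_blocking)
    return x, y

trainer = create_supervised_trainer(
    model, optimizer, criterion,
    device="cuda",
    non_blocking=True,
    prepare_batch=my_prepare_batch,
)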

There is also the analogous ignite.engine.create_supervised_evaluator, which takes fewer parameters than the trainer factory:

def create_supervised_evaluator(
    model: torch.nn.Module,
    metrics: Optional[Dict[str, Metric]] = None,
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
) -> Engine:

model: the trained model

metrics: a mapping of metric names to Metrics

device: device type specification (default: None). The device can be CPU, GPU or TPU.

output_transform: a function that receives x, y and y_pred and returns the value to be assigned to the engine's state.output after each iteration. Defaults to returning (y_pred, y), which fits the output expected by the metrics. If you change it, the metrics need their own output_transform.
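For instance, if output_transform is changed to return a dict, each metric needs its own output_transform to recover (y_pred, y); a minimal sketch:

evaluator = create_supervised_evaluator(
    model,
    metrics={"accuracy": Accuracy(output_transform=lambda out: (out["y_pred"], out["y"]))},
    output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred},
)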

[Example 5] Resume training from a checkpoint

It is possible to resume training from a checkpoint and approximately reproduce the behavior of the original run. With ignite this is easily done using the Checkpoint handler. The engine provides two methods to serialize and deserialize its internal state, state_dict() and load_state_dict(). Besides serializing the model, optimizer, lr scheduler, etc., the user can store the trainer itself and later restore the training. For example:

from ignite.handlers import Checkpoint, DiskSaver

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...

to_save = {'trainer': trainer,
           'model': model,
           'optimizer': optimizer,
           'lr_scheduler': lr_scheduler}
handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))

trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.run(data_loader, max_epochs=100)

We can then resume training from the last checkpoint:

from ignite.handlers import Checkpoint

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...

to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
checkpoint = torch.load(checkpoint_file)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

trainer.run(train_loader, max_epochs=100)

4. Events & Handlers

To make the Engine more flexible, an event system was introduced that facilitates interaction at each step of the run:

engine is started/completed

epoch is started/completed

batch iteration is started/completed

The detailed list of events can be found in ignite.engine.events.
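The most commonly used built-in members of Events include, for instance:

Events.STARTED               # the run has started
Events.EPOCH_STARTED         # an epoch has started
Events.ITERATION_STARTED     # an iteration has started
Events.ITERATION_COMPLETED   # an iteration has finished
Events.EPOCH_COMPLETED       # an epoch has finished
Events.COMPLETED             # the run has finished
Events.EXCEPTION_RAISED      # an exception occurred during the run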

The following pseudocode shows what the Engine's run() method executes:

fire_event(Events.STARTED)
while epoch < max_epochs:
    fire_event(Events.EPOCH_STARTED)
    # run once on data
    for batch in data:
        fire_event(Events.ITERATION_STARTED)
        output = process_function(batch)
        fire_event(Events.ITERATION_COMPLETED)
    fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)

The code above shows where each event is fired.

There are two ways to hook into events: add_event_handler(), or the on() decorator:

trainer = Engine(update_model)

trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))

# or
@trainer.on(Events.STARTED)
def on_training_started(engine):
    print("Another message of start training")

# or even simpler, use only what you need !
@trainer.on(Events.STARTED)
def on_training_started():
    print("Another message of start training")

# attach handler with args, kwargs
mydata = [1, 2, 3, 4]

def on_training_ended(data):
    print("Training is ended. mydata={}".format(data))

trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)

add_event_handler() can also be used to attach handlers dynamically; the handle it returns works as a context manager that removes the handler on exit:

model = ...
train_loader, validation_loader, test_loader = ...

trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})

def log_metrics(engine, title):
    print("Epoch: {} - {} accuracy: {:.2f}"
          .format(trainer.state.epoch, title, engine.state.metrics["acc"]))

@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):
        evaluator.run(train_loader)

    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):
        evaluator.run(validation_loader)

    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):
        evaluator.run(test_loader)

trainer.run(train_loader, max_epochs=100)

Event handlers can also be configured to be called in a user-defined pattern: every n-th event, once, or with a custom event filter function:

model = ...
train_loader, validation_loader, test_loader = ...

trainer = create_supervised_trainer(model, optimizer, loss)

@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_training_loss_every_50_iterations():
    print("{} / {} : {} - loss: {:.2f}"
          .format(trainer.state.epoch, trainer.state.max_epochs, trainer.state.iteration, trainer.state.output))

@trainer.on(Events.EPOCH_STARTED(once=25))
def do_something_once_on_25_epoch():
    # do something
    pass

def custom_event_filter(engine, event):
    if event in [1, 2, 5, 10, 50, 100]:
        return True
    return False

@trainer.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):
    # do something on 1, 2, 5, 10, 50, 100 iterations
    pass

trainer.run(train_loader, max_epochs=100)

Custom Events can also be defined:

from ignite.engine.events import EventEnum

class CustomEvents(EventEnum):
    """
    Custom events defined by user
    """
    CUSTOM_STARTED = 'custom_started'
    CUSTOM_COMPLETED = 'custom_completed'

engine.register_events(*CustomEvents)

Multiple events can also be bound to a single handler at once:

events = Events.STARTED | Events.COMPLETED | Events.ITERATION_STARTED(every=3)

engine = ...

@engine.on(events)
def call_on_events(engine):
    # do something
    pass

These custom events can be used to attach any handler, and are triggered with fire_event():

@engine.on(CustomEvents.CUSTOM_STARTED)
def call_on_custom_event(engine):
    # do something
    pass

@engine.on(Events.STARTED)
def fire_custom_events(engine):
    engine.fire_event(CustomEvents.CUSTOM_STARTED)

Note that a handler function does not have to take the engine as an argument: if the engine is not needed the handler can take no arguments at all, and it can also take multiple additional arguments.

Event filters can also be passed when subscribing to engine events:

engine = Engine(...)

# a) custom event filter
def custom_event_filter(engine, event):
    if event in [1, 2, 5, 10, 50, 100]:
        return True
    return False

@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):
    # do something on 1, 2, 5, 10, 50, 100 iterations
    pass

# b) "every" event filter
@engine.on(Events.ITERATION_STARTED(every=10))
def call_every(engine):
    # do something every 10th iteration
    pass

# c) "once" event filter
@engine.on(Events.ITERATION_STARTED(once=50))
def call_once(engine):
    # do something on 50th iteration
    pass

5. Built-in Handlers

The library provides a set of built-in handlers for checkpointing the training pipeline, saving the best model, stopping training when there is no improvement, logging to experiment tracking systems, and so on. They can be found in the following two modules:

ignite.handlers

ignite.contrib.handlers

Some of these classes can simply be added to an Engine as callables. For example:

from ignite.handlers import TerminateOnNan

trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
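Another built-in example is EarlyStopping, which stops training when a monitored score stops improving; a sketch reusing the trainer, evaluator and "accuracy" metric from section 2:

from ignite.handlers import EarlyStopping

def score_function(engine):
    # higher is better, so return the validation accuracy
    return engine.state.metrics["accuracy"]

early_stopping = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
# check the score each time validation completes
evaluator.add_event_handler(Events.COMPLETED, early_stopping)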

In addition, an attach() method is provided for manually adding handlers to an Engine during program execution:

from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the trainer to log model's weights as a histogram after each epoch
tb_logger.attach(
    trainer,
    event_name=Events.EPOCH_COMPLETED,
    log_handler=WeightsHistHandler(model)
)

6. State

State is used to store the Engine's output and run information; every Engine object has a state attribute:

engine.state.seed: seed to set at each data "epoch".

engine.state.epoch: number of epochs the engine has completed. Initialized as 0; the first epoch is 1.

engine.state.iteration: number of iterations the engine has completed. Initialized as 0; the first iteration is 1.

engine.state.max_epochs: number of epochs to run for. Initialized as 1.

engine.state.output: the output of the process_function defined for the Engine.

etc.

The remaining attributes can be looked up in the API documentation.

In the code below, engine.state.output stores the batch loss, and this output is used to print the loss at every iteration:

def update(engine, batch):
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update)

def on_iteration_completed(engine):
    iteration = engine.state.iteration
    epoch = engine.state.epoch
    loss = engine.state.output
    print("Epoch: {}, Iteration: {}, Loss: {}".format(epoch, iteration, loss))

trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)

In the code below, engine.state.output will be a tuple of the processed batch's loss, y_pred and y. To attach Accuracy to this engine, an output_transform is needed to extract y_pred and y from engine.state.output:

def update(engine, batch):
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), y_pred, y

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output[0]
    print('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))

accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')

trainer.run(data, max_epochs=10)

Similar to the above, but this time the output of process_function is a dictionary holding the loss, y_pred and y of the processed batch; this is how a user can use output_transform to get y_pred and y from engine.state.output:

def update(engine, batch):
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {'loss': loss.item(),
            'y_pred': y_pred,
            'y': y}

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output['loss']
    print('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))

accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')

trainer.run(data, max_epochs=10)

It is also good practice to use State to store user data created in update or handler functions. For example, to store new_attribute in the state:

def user_handler_function(engine):
    engine.state.new_attribute = 12345
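Another handler can then read the stored value back later in the run:

@trainer.on(Events.COMPLETED)
def read_user_attribute(engine):
    # value stored earlier by user_handler_function
    print(engine.state.new_attribute)  # 12345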

7. Metrics

The library provides a list of ready-made metrics for various machine learning tasks. Two ways of computing metrics are supported: 1) online and 2) storing the entire output history.

A metric can be attached to an Engine:

from ignite.metrics import Accuracy

accuracy = Accuracy()
accuracy.attach(evaluator, "accuracy")

state = evaluator.run(validation_data)
print("Result:", state.metrics)
# > {"accuracy": 0.12345}
