学习PyTorch

  • DataLoader

torch.utils.data.DataLoader:
Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.

  • Flatten

torch.nn.Flatten(start_dim=1, end_dim=- 1)

  • model.train() & model.eval()

在使用Pytorch进行模型的训练和测试时,我们总能在训练部分的最前面看到model.train(),在测试部分最前面看到model.eval()
Pytorch model.train()_长命百岁️的博客-CSDN博客_pytorch中model.train

主要是对Batch NormalizationDropout 层有影响。因为这两层在训练和测试时进行的操作是不同的。

model.train()
设置模型为训练模式,即
BatchNorm 层利用每个 batch 来统计
Dropout 层激活

model.eval()
设置模型为评估/推理模式,即
BatchNorm layers use running statistics
Dropout 层取消。
等效于 model.train(False)

参数量估计

https://blog.csdn.net/qq_33952811/article/details/124276599


学习PyTorch
https://cosmicdusty.cc/post/AI/LearnPyTorch/
作者
Murphy
发布于
2022年9月7日
许可协议