学习PyTorch
- DataLoader
torch.utils.data.DataLoader:
Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
- Flatten
torch.nn.Flatten(start_dim=1, end_dim=- 1)
- model.train() & model.eval()
在使用Pytorch
进行模型的训练和测试时,我们总能在训练部分的最前面看到model.train()
,在测试部分最前面看到model.eval()
。
Pytorch model.train()_长命百岁️的博客-CSDN博客_pytorch中model.train
主要是对
Batch Normalization
和Dropout
层有影响。因为这两层在训练和测试时进行的操作是不同的。
model.train()
设置模型为训练模式,即
• BatchNorm
层利用每个 batch 来统计
• Dropout
层激活
model.eval()
设置模型为评估/推理模式,即
• BatchNorm
layers use running statistics
•Dropout
层取消。
等效于 model.train(False)
。
参数量估计
学习PyTorch
https://cosmicdusty.cc/post/AI/LearnPyTorch/