Browse Source

Add 'test_infinite.py'

master
huaqimao 7 months ago
parent
commit
999dd90b97
1 changed files with 40 additions and 0 deletions
  1. +40
    -0
      test_infinite.py

+ 40
- 0
test_infinite.py View File

@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import numpy as np
from tensorboardX import SummaryWriter
# 构建输入集
x = np.mat('0 0;'
'0 1;'
'1 0;'
'1 1')
x = torch.tensor(x).float()
y = np.mat('1;'
'0;'
'0;'
'1')
y = torch.tensor(y).float()
# 搭建网络
myNet = nn.Sequential(
nn.Linear(2, 10),
nn.ReLU(),
nn.Linear(10, 1),
nn.Sigmoid()
)
print(myNet)
# 设置优化器
optimzer = torch.optim.SGD(myNet.parameters(), lr=0.05)
loss_func = nn.MSELoss()
writer = SummaryWriter('/logs')
for epoch in range(5000):
out = myNet(x)
loss = loss_func(out, y) # 计算误差
optimzer.zero_grad() # 清除梯度
loss.backward()
optimzer.step()
print(loss)
writer.add_scalar("loss", loss, epoch)
print(myNet(x).data)

Loading…
Cancel
Save