Найти в Дзене

Как я сделал простую нейросеть для рисовки цифр...

Кто хочет попробовать запустить, вот код: import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
latent_dim = 100
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 784),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 1
Первые попытки на маленьком кол эпох обучения...
Первые попытки на маленьком кол эпох обучения...
Результат на 300 эпохах!
Результат на 300 эпохах!

Кто хочет попробовать запустить, вот код:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

latent_dim = 100

class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 784),
nn.Tanh()
)

def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 1, 28, 28)
return img.to(device)

class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)

def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity.to(device)

def generate_images(generator, num_images=25, figsize=(5, 5)):
generator.eval()
with torch.no_grad():
z = torch.randn(num_images, latent_dim).to(device)
generated_images = generator(z).cpu().detach()
generator.train()

plt.figure(figsize=figsize)
for i in range(num_images):
plt.subplot(5, 5, i + 1)
plt.imshow(generated_images[i].view(28, 28), cmap='gray')
plt.axis('off')
plt.show()

generator = Generator().to(device)
generator.load_state_dict(torch.load('generator.pth', map_location=device))

discriminator = Discriminator().to(device)
discriminator.load_state_dict(torch.load('discriminator.pth', map_location=device))

generate_images(generator, num_images=25, figsize=(8, 8))

Ссылки на нужные файлы --- https://drive.google.com/file/d/1KR-ohJSRH6h6vmAdHfUv73vMb5G25r96/view?usp=drive_link, https://drive.google.com/file/d/18SejNRDfqQoXD8-zMzT7OnNzJlAcfyQu/view?usp=drive_link

Тгк ------- https://t.me/+UKpct0X5M6FmNzIy