import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import torch.nn as nn
from torch import optim
# Convert images to tensors and normalize each channel to roughly [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
BATCH_SIZE = 16
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)
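# Optional sanity check (a sketch, not part of the original listing): after the
# transform above, each CIFAR-10 batch should be 16 normalized 3x32x32 images.
# x, y = next(iter(trainloader))
# print(x.shape, y.shape)  # expected: torch.Size([16, 3, 32, 32]) torch.Size([16])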
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
model = ViT(patch_size=16,
            img_dimension=(32, 32),
            latent_size=1024,
            num_heads=2,
            num_classes=10)
LR = 1e-3
optimizer = optim.Adam(model.parameters(), lr=LR, amsgrad=True)
loss_fn = nn.CrossEntropyLoss()
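# Optional shape check (a sketch, assuming the ViT class defined earlier maps a
# (B, 3, 32, 32) batch to (B, num_classes) logits):
# with torch.no_grad():
#     dummy = torch.randn(2, 3, 32, 32)
#     print(model(dummy).shape)  # expected: torch.Size([2, 10])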
def train_1_epoch(train_loader):
    """Run one training epoch and return the mean batch loss."""
    model.train()
    loss_value = 0
    cnt = 0
    for (x, y) in tqdm(train_loader):
        logits = model(x)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value += loss.item()
        cnt += 1
    return loss_value / cnt
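# Note: the loop above runs on CPU as written. A minimal GPU variant would look
# like the sketch below (assumption: a CUDA device is available; this is not
# part of the original snippet):
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# # ...and inside the batch loop: x, y = x.to(device), y.to(device)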
@torch.no_grad()
def eval(testloader):
    """Evaluate on the test set; return mean batch loss and accuracy."""
    model.eval()
    loss_value = 0
    cnt = 0
    num_correct = 0
    num_samples = 0
    for (x, y) in tqdm(testloader):
        logits = model(x)
        loss = loss_fn(logits, y)
        loss_value += loss.item()
        pred = logits.argmax(1)
        num_correct += (pred == y).sum().item()
        num_samples += len(y)
        cnt += 1
    model.train()
    return loss_value / cnt, num_correct / num_samples
def train(epochs):
    for epoch in range(epochs):
        train_loss = train_1_epoch(trainloader)
        val_loss, val_acc = eval(testloader)
        print(f"Epoch: {epoch} Train Loss: {train_loss:.4f} "
              f"Validation Loss: {val_loss:.4f} Val Acc: {val_acc:.4f}")

train(10)
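# Optional (not in the original): persist the trained weights so the model can
# be reloaded later; "vit_cifar10.pt" is an illustrative filename.
# torch.save(model.state_dict(), "vit_cifar10.pt")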