from Utils import MnistLoadData
from Utils import CIFARLoadData
from Models.VAE_Model.Parser_args import parse_Arg
from Models.VAE_Model.VAE import vae
from torchvision.utils import save_image
from Utils import get_device

# Parse the VAE command-line options and derive the channels-first image
# shape (channels, image_size, image_size) from them.
args = parse_Arg()
image_shape = (args.channels,) + (args.image_size,) * 2

# CIFAR training loader; the two boolean flags are forwarded positionally to
# the project helper CIFARLoadData (see its definition for their meaning).
data_loader = CIFARLoadData(args.batch_size, True, True)

device = get_device()

# The VAE wrapper is built from the parsed options and the selected device.
model = vae(args, device)

import os

# save_image() does not create missing directories, so make sure the sample
# output folder exists before the first image is written.
os.makedirs("images", exist_ok=True)

# Train the VAE, logging the loss for every batch and periodically saving a
# grid of reconstructions of the current batch.
for epoch in range(args.n_epochs):
    for i, (inputs, _) in enumerate(data_loader):
        # Flatten each image into a vector. The final batch of an epoch can be
        # smaller than args.batch_size, so use the actual batch size.
        current_batch_size = inputs.size(0)
        inputs = inputs.view(current_batch_size, args.input_dim * args.channels).to(device)

        loss = model.learn(inputs)

        print("[Epoch %d/%d] [Batch %d/%d] [loss: %f]" % (epoch + 1, args.n_epochs, i + 1, len(data_loader), loss))

        batches_done = epoch * len(data_loader) + i
        if batches_done % args.sample_interval == 0:
            # Reshape back to (batch, channels, H, W) using the real batch size:
            # reshaping with args.batch_size would raise on a partial batch.
            output = model(inputs).detach().reshape(
                current_batch_size, args.channels, args.image_size, args.image_size
            )
            save_image(output, "images/%d.png" % batches_done, nrow=args.nrow, normalize=True)

import torch
from Models.Resnet_Model.Resnet import ResNet
from Utils import MnistLoadData
from Models.Resnet_Model.Parser_args import parse_Arg
from Utils import get_device
from Utils import CIFARLoadData

# Parse the ResNet command-line options and build the CIFAR loaders. The
# training loader is created first, then the test loader (second positional
# flag True, then False), matching the project helper's argument order.
args = parse_Arg()
train_loader, test_loader = (
    CIFARLoadData(args.batch_size, train_flag, False) for train_flag in (True, False)
)

device = get_device()

# ResNet classifier (original note: "50 Resnet", presumably ResNet-50; the
# actual depth is whatever args.layers specifies), moved to the chosen device.
model = ResNet(args.channels, args.layers).to(device)

# Run training: for every epoch, feed each mini-batch to the model and log
# the loss it reports.
for epoch in range(args.n_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # model.Learn is the project's per-batch step (presumably forward,
        # backward and optimizer update); it returns the loss that is logged.
        loss = model.Learn(inputs, labels)

        print("[Epoch %d/%d] [Batch %d/%d] [loss: %f]" % (epoch + 1, args.n_epochs, i + 1, len(train_loader), loss))

# Evaluation
with torch.no_grad():
    # Put the model into evaluation mode (assuming ResNet is a torch nn.Module).
    model.eval()

    # NOTE(review): this reads `test_data` / `test_labels` from the object that
    # CIFARLoadData returned. If that is a torch DataLoader, these attributes do
    # not exist on it (the dataset lives at `test_loader.dataset`, and recent
    # torchvision datasets name them `data` / `targets`). Also, len(test_loader)
    # on a DataLoader is the number of batches, not samples. The hard-coded
    # 28x28 matches MNIST, whereas CIFAR images are 32x32. TODO confirm all of
    # this against CIFARLoadData before relying on these tensors.
    X_test = test_loader.test_data.view(len(test_loader), args.channels, 28, 28).float()
    Y_test = test_loader.test_labels