コード例 #1
0
ファイル: train.py プロジェクト: YiKeYaTu/Experiments
from torch.utils.data import DataLoader
from constant import DEVICE, LEARNING_RATE, ITERATION_SIZE, WEIGHT_DECAY, TMP_ROOT
from utils.StatisticalValue import StatisticalValue
from utils.functions.status import print_training_status
from loss_functions.multi_angular_loss import multi_angular_loss
from env import iteration_writer
from torchvision import transforms
from os.path import join
import torch
import torchvision
import numpy as np
import os

# Training split of the project's ImageNet dataset; both the input image and
# the target are converted with torchvision's ToTensor.
train_dataset = ImageNet(
    train=True,
    transform=transforms.Compose([transforms.ToTensor()]),
    target_transform=transforms.Compose([transforms.ToTensor()]),
)

# NOTE(review): shuffle=False on a *training* loader means every epoch visits
# samples in the same order — confirm this is intentional.
trainloader = DataLoader(
    train_dataset,
    batch_size=10,
    shuffle=False,
    num_workers=8,
)

# Instantiate the network and move its parameters/buffers to the target device.
model = ResNetMCC()
model.to(device=DEVICE)

# Squared error summed (not averaged) over all elements of the batch.
criterion = torch.nn.MSELoss(reduction='sum')
# criterion = multi_angular_loss
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# optimizer = torch.optim.SGD(
コード例 #2
0
from dataloaders.multi_color_constancy.ImageNet import ImageNet
from models.resnet.ResNetMCC import ResNetMCC
from torch.utils.data import DataLoader
from constant import DEVICE, TMP_ROOT
from utils.StatisticalValue import StatisticalValue
from loss_functions.multi_angular_loss import multi_angular_loss
from torchvision import transforms
import torch
import torchvision
import os
import time
from thop import profile

# Evaluation split (train=False) of the project's ImageNet dataset; input
# images and targets are both converted with torchvision's ToTensor.
dataset = ImageNet(
    train=False,
    transform=transforms.Compose([transforms.ToTensor()]),
    target_transform=transforms.Compose([transforms.ToTensor()]),
)

# One sample per batch, visited in dataset order.
testloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
)

# Instantiate the network and move its parameters/buffers to the target device.
model = ResNetMCC()
model.to(device=DEVICE)

# Count MACs and parameters with thop using one random 1x3x224x224 input.
# NOTE(review): the model has not been switched to eval() here, so this
# forward pass runs in the default (train) mode — if the network contains
# BatchNorm layers, their running statistics are updated by the random input.
# Confirm that later code reloads weights or that this is acceptable.
macs, params = profile(
    model,
    inputs=(torch.randn(1, 3, 224, 224).to(DEVICE), ),
)
print("Model's macs is %f, params is %f" % (macs, params))


def run():
    """Entry point of this script (only its opening lines are visible here).

    The visible statements create a fresh ``StatisticalValue`` instance and a
    local-time timestamp string; the remainder of the body is not shown.
    """
    # Presumably accumulates per-sample angular errors (guess from the name — TODO confirm).
    statistical_angular_errors = StatisticalValue()
    # Local-time timestamp such as "2021-03-04-12-30-00"; presumably used as a
    # per-run output sub-directory name (verify against the rest of the body).
    sub_dir = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
コード例 #3
0
ファイル: train.py プロジェクト: YiKeYaTu/Experiments
from utils.functions.status import print_training_status
from loss_functions.multi_angular_loss import multi_angular_loss
from env import iteration_writer
from torchvision import transforms
from os.path import join
import torch
import torchvision
import numpy as np
import os

# Training split of the project's ImageNet dataset; both the input image and
# the target are converted with torchvision's ToTensor.
train_dataset = ImageNet(
    train=True,
    transform=transforms.Compose([transforms.ToTensor()]),
    target_transform=transforms.Compose([transforms.ToTensor()]),
)

# NOTE(review): shuffle=False on a *training* loader means every epoch visits
# samples in the same order — confirm this is intentional.
trainloader = DataLoader(
    train_dataset,
    batch_size=10,
    shuffle=False,
    num_workers=8,
)

# Deeper variant of the network (layer_count=152), moved to the target device.
model = ResNetMCC(layer_count=152)
model.to(device=DEVICE)

# Squared error summed (not averaged) over all elements of the batch.
criterion = torch.nn.MSELoss(reduction='sum')
# criterion = multi_angular_loss
optimizer = torch.optim.Adam(