Beispiel #1
0
from resnet import ResNet18
import torch
import torch.nn as nn

# Build a 1000-class ResNet-18 and restore pretrained ImageNet weights.
myNet = ResNet18(1000)
# map_location="cpu" lets the checkpoint load on machines without a GPU;
# load_state_dict then copies the tensors into the model's own parameters.
myNet.load_state_dict(torch.load("resnet18-5c106cde.pth", map_location="cpu"))
Beispiel #2
0
def main():
    """Train ResNet-18 on CIFAR-10 for 1000 epochs, printing the loss of the
    final batch and the test-set accuracy after every epoch."""
    batchsz = 32

    # Training split: resize to 32x32 and convert images to [0, 1] tensors.
    cifar_train = datasets.CIFAR10('cifar',
                                   True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor()
                                   ]),
                                   download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar',
                                  False,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      transforms.ToTensor()
                                  ]),
                                  download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # Peek at one batch to sanity-check shapes.
    # Bug fix: Python 3 iterators have no .next() method -- use next().
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')
    # model = Lenet5().to(device)
    model = ResNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # [b, 3, 32, 32]
            # [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10]
            # label:  [b]
            # loss: tensor scalar
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the final batch of the epoch.
        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # [b, 3, 32, 32]
                # [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)

            acc = total_correct / total_num
            print(epoch, 'acc:', acc)
Beispiel #3
0
    num_workers=2)  #生成一个个batch进行批训练,组成batch的时候顺序打乱取

testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=False,
                                       transform=transform_test)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=100,
                                         shuffle=False,
                                         num_workers=2)
# Cifar-10的标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

# 模型定义-ResNet
net = ResNet18().to(device)

# 定义损失函数和优化方式
criterion = nn.CrossEntropyLoss()  #损失函数为交叉熵,多用于多分类问题
optimizer = optim.SGD(
    net.parameters(), lr=LR, momentum=0.9,
    weight_decay=5e-4)  #优化方式为mini-batch momentum-SGD,并采用L2正则化(权重衰减)

# 训练
if __name__ == "__main__":
    best_acc = 85  #2 初始化best test accuracy
    print("Start Training, Resnet-18!")  # 定义遍历数据集的次数
    with open("acc.txt", "w") as f:
        with open("log.txt", "w") as f2:
            for epoch in range(pre_epoch, EPOCH):
                print('\nEpoch: %d' % (epoch + 1))
# Let cuDNN benchmark convolution algorithms; faster after warm-up when input
# sizes are fixed.
torch.backends.cudnn.benchmark = True

trainset = DatasetFile(cfg.root, args.train_file, transform=transform_train)
trainloader = DataLoader(trainset,
                         batch_size=args.bs,
                         shuffle=True,
                         num_workers=args.workers)

testset = DatasetFile(cfg.root, args.valid_file, transform=transform_test)
testloader = DataLoader(testset,
                        batch_size=args.bs,
                        shuffle=False,  # keep evaluation order deterministic
                        num_workers=args.workers)

# Model width/depth settings come from the config and CLI flags.
net = ResNet18(dim=cfg.nb_class, r=args.r, c=args.init_channels)
if args.pth is not None:
    # Optionally warm-start from a checkpoint before moving to the GPU.
    net.load_state_dict(torch.load(args.pth))
net.cuda()
net.train()
opt = torch.optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=0.9,
                      weight_decay=args.wd)
ce_loss = nn.CrossEntropyLoss()

epoch = args.epoch
best_acc = 0.0

for i in range(epoch):
    if i + 1 in [int(epoch * 0.5), int(epoch * 0.8)]:
Beispiel #5
0
# Parse CLI flags and persist them alongside the run's log for reproducibility.
args = parser.parse_args()
logger = LogSaver(args.logdir)
logger.save(str(args), 'args')

# data
dataset = CIFAR10(args.datadir)
logger.save(str(dataset), 'dataset')
# presumably (batch size, shuffle/fixed flag) -- verify against CIFAR10.getTestList
test_list = dataset.getTestList(500, True)

# model
start_iter = 0
lr = args.lr
if args.model == 'resnet':
    from resnet import ResNet18
    model = ResNet18().cuda()
elif args.model == 'vgg':
    from vgg import vgg11
    model = vgg11().cuda()
else:
    raise NotImplementedError()
criterion = CEwithMask
optimizer = torch.optim.SGD(model.parameters(),
                            lr=lr,
                            momentum=args.momentum,
                            weight_decay=args.weightdecay)
if args.resume:
    # Resume: restore iteration counter, learning rate and model weights.
    checkpoint = torch.load(args.resume)
    start_iter = checkpoint['iter'] + 1
    lr = checkpoint['lr']
    model.load_state_dict(checkpoint['model'])
Beispiel #6
0
    def load_network(self):
        """Create the model, the loss function, and the optimizer chosen via
        ``self.args.optimizer`` (case-insensitive), falling back to SGD with
        momentum for unrecognized names."""
        logger.info("Start loading network, loss function and optimizer")

        # Load a network
        # self.net = VGG('VGG11')
        self.net = ResNet18()

        # Move network to GPU if needed
        if self.args.gpu:
            self.net.to('cuda')

        # Define the loss function and the optimizer
        self.criterion = nn.CrossEntropyLoss()

        params = self.net.parameters()
        # Each supported optimizer maps to a zero-argument factory carrying its
        # default hyper-parameters, turning the old if/elif chain into a lookup.
        factories = {
            'adadelta': lambda: optim.Adadelta(
                params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0),
            'adagrad': lambda: optim.Adagrad(
                params, lr=0.01, lr_decay=0, weight_decay=0,
                initial_accumulator_value=0),
            'adam': lambda: optim.Adam(
                params, lr=0.001, betas=(0.9, 0.999), eps=1e-08,
                weight_decay=0, amsgrad=False),
            'sparseadam': lambda: optim.SparseAdam(
                params, lr=0.001, betas=(0.9, 0.999), eps=1e-08),
            'adamax': lambda: optim.Adamax(
                params, lr=0.002, betas=(0.9, 0.999), eps=1e-08,
                weight_decay=0),
            'asgd': lambda: optim.ASGD(
                params, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0,
                weight_decay=0),
            'lbfgs': lambda: optim.LBFGS(
                params, lr=1, max_iter=20, max_eval=None,
                tolerance_grad=1e-05, tolerance_change=1e-09,
                history_size=100, line_search_fn=None),
            'rmsprop': lambda: optim.RMSprop(
                params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0,
                momentum=0, centered=False),
            'rprop': lambda: optim.Rprop(
                params, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50)),
            'sgd': lambda: optim.SGD(
                params, lr=0.001, momentum=0, dampening=0, weight_decay=0,
                nesterov=False),
        }

        choice = self.args.optimizer.lower()
        if choice in factories:
            logger.info("Selected %s as optimizer" % choice)
            self.optimizer = factories[choice]()
        else:
            logger.info("Unknown optimizer given, SGD is chosen instead.")
            self.optimizer = optim.SGD(params, lr=0.001, momentum=0.9)

        logger.info(
            "Loading network, loss function and %s optimizer was successful",
            self.args.optimizer)
Beispiel #7
0
def main():
    """Train ResNet-18 on CIFAR-10, printing the last-batch loss and the
    test accuracy after every epoch."""

    batch_size = 32

    cifar_train = datasets.CIFAR10('cifar',True,transform=transforms.Compose([
            transforms.Resize((32,32)),
            transforms.ToTensor()
        ]),download=True)

    cifar_train = DataLoader(cifar_train,batch_size=batch_size,shuffle=True)

    cifar_test = datasets.CIFAR10('cifar',False,transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]),download=True)

    cifar_test = DataLoader(cifar_test,batch_size=batch_size,shuffle=True)


    # Peek at one batch to sanity-check shapes.
    # Bug fix: Python 3 iterators have no .next() method -- use next().
    x,label = next(iter(cifar_train))
    print('x:',x.shape,'label:',label.shape)

    device = torch.device('cuda')
    # model = LeNet5().to(device)
    model = ResNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(),lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx,(x,label) in enumerate(cifar_train):

            #[b,3,32,32]
            #[b]
            x,label = x.to(device),label.to(device)

            logits = model(x)  # forward pass
            #logits:[b,10]
            #label:[b]
            #loss:tensor scalar
            loss = criteon(logits,label)  # measure the prediction error

            #backprop
            optimizer.zero_grad()  # clear gradients left over from the previous batch
            loss.backward()  # compute gradients via backpropagation
            optimizer.step()  # update parameters from the gradients

        # Loss of the final batch of the epoch.
        print(epoch,loss.item())

        model.eval()  # switch off train-only behaviour such as dropout
        with torch.no_grad():  # no autograd graph needed for evaluation
            # test
            total_correct = 0
            total_num = 0
            for x,label in cifar_test:
                #[b,3,32,32]
                #[b]
                x, label = x.to(device) , label.to(device)
                #[b,10]
                logits = model(x)
                #[b]
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred,label).float().sum().item()
                total_num += x.size(0)

            acc = total_correct/total_num
            print(epoch,acc)
Beispiel #8
0
            torch.load(
                'validation/autoencoder_checkpoints/final_classification_model-stimulus.pkl'
            ))
        classifier_id.load_state_dict(
            torch.load(
                'validation/autoencoder_checkpoints/final_classification_model-id.pkl'
            ))

        autoencoder_alcoholism.eval()
        autoencoder_stimulus.eval()
        autoencoder_id.eval()

    #### load ResNet-based model ####
    else:  #the model is specified by the arg --classifier
        if opt.classifier == 'ResNet18':
            classifier_alcoholism = ResNet18(num_classes_alc)
            classifier_stimulus = ResNet18(num_classes_stimulus)
            classifier_id = ResNet18(num_classes_id)
        elif opt.classifier == 'ResNet34':
            classifier_alcoholism = ResNet34(num_classes_alc)
            classifier_stimulus = ResNet34(num_classes_stimulus)
            classifier_id = ResNet34(num_classes_id)
        elif opt.classifier == 'ResNet50':
            classifier_alcoholism = ResNet50(num_classes_alc)
            classifier_stimulus = ResNet50(num_classes_stimulus)
            classifier_id = ResNet50(num_classes_id)

        classifier_alcoholism = classifier_alcoholism.to(device)
        classifier_stimulus = classifier_stimulus.to(device)
        classifier_id = classifier_id.to(device)
        if device == 'cuda':
     
 elif cfg['dataset'] == 'aircraft':
     
     transform = torchvision.transforms.Compose([
         torchvision.transforms.Resize([224,224], interpolation=PIL.Image.BICUBIC),
         torchvision.transforms.ToTensor(),
         torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     ])
     
     trainset = torchvision.datasets.ImageFolder(root='../data/aircraft-100/train/', transform=transform)
     testset = torchvision.datasets.ImageFolder(root='../data/aircraft-100/test/', transform=transform)
 
         
 if cfg['dataset'] in ['cifar10', 'cifar100']:
     if cfg['network'] == 'resnet18':
         model = ResNet18(cfg['dim'], 0, 0, 'dummy').cuda()
 else:
     print("Cub, Cars, Tiny ImageNet and Dogs model init.")
     
     model = torchvision.models.resnet50(pretrained=False)
     in_ftr  = model.fc.in_features
     out_ftr = cfg['dim']
     model.fc = nn.Linear(in_ftr,out_ftr,bias=True)
     
     if cfg['dataset'] == 'imagenet':
         model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
         
 model.load_state_dict(a['state_dict'])
 model = model.cuda()
 model.eval()
 
Beispiel #10
0
        trainset.targets)[rand_perm[:len(trainset)]]).tolist()

# Seed after the target permutation above so model init is reproducible.
torch.manual_seed(args.seed)
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=2)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=2)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Select the architecture; the ResNets share the same classifier settings.
if args.model == 'resnet18':
    net = ResNet18(num_classes=num_classes, linear_base=linear_base)
elif args.model == 'resnet50':
    net = ResNet50(num_classes=num_classes, linear_base=linear_base)
elif args.model == 'convnet':
    net = ConvNet(num_classes=num_classes)

best_acc = 0
if args.saved_model != '':
    # strict=False tolerates missing/unexpected keys (e.g. a replaced head).
    checkpoint = torch.load(args.saved_model)
    net.load_state_dict(checkpoint['net'], strict=False)
    #best_acc = checkpoint['acc']

net = net.to(device)
criterion = nn.CrossEntropyLoss()

Beispiel #11
0
def main():
    """Train ResNet-18 on normalized CIFAR-10 and report per-epoch loss and
    test-set accuracy."""
    batchsz = 32
    # The dataset yields one (image, label) sample at a time.
    cifar_train = datasets.CIFAR10('cifar',
                                   True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])
                                   ]),
                                   download=True)
    # The DataLoader groups samples into shuffled batches.
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar',
                                  False,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(
                                          mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
                                  ]),
                                  download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # Fetch one batch to sanity-check shapes.
    # Bug fix: Python 3 iterators have no .next() method -- use next().
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')
    # model = Lenet5().to(device)
    model = ResNet18().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # Print the model structure.
    print(model)

    for epoch in range(1000):
        # train()/eval() matter because dropout and normalization layers
        # behave differently during training and testing.
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # [b, 3, 32, 32]
            # [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10]
            # label: [b]
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # loss: tensor scalar (final batch of the epoch)
        print("epoch:", epoch, "loss:", loss.item())

        model.eval()
        # No gradient bookkeeping is needed below -- evaluation only.
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # [b, 3, 32, 32]
                # [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)

            acc = total_correct / total_num
            print("epoch:", epoch, "test acc:", acc)
Beispiel #12
0
        epoch + 1, correct, correct_rate))

    tb_logger.add_scalar("val_accuracy", correct_rate.item(), epoch)

    torch.save({
        'params': net.state_dict(),
        'epoch': epoch + 1
    }, f'./model/{model_name}.pth')


if __name__ == "__main__":
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = f'{args.schedule}{args.epoch}'
    tsa_enable = args.tsa
    total_epoch = int(args.epoch)
    net = ResNet18().to(DEVICE)
    if not os.path.exists(f'./model'):
        os.mkdir('model')
    # Resume from a checkpoint if one exists for this schedule/epoch name.
    if os.path.exists(f'./model/{model_name}.pth'):
        load = torch.load(f'./model/{model_name}.pth')
        current_epoch = load['epoch']
        net.load_state_dict(load['params'])
        # first argument presumably flags "resuming" vs "fresh" -- TODO confirm
        # against dataload's definition
        sup_dataloader, val_dataloader, unsup_dataloader = dataload(
            1, args.supnum)
    else:
        current_epoch = 0
        sup_dataloader, val_dataloader, unsup_dataloader = dataload(
            0, args.supnum)
    total_step = total_epoch * len(unsup_dataloader)
    # TSA needs per-sample losses, hence reduction='none' when it is enabled.
    sup_criterion = nn.CrossEntropyLoss(
        reduction='none') if tsa_enable == '1' else nn.CrossEntropyLoss()
Beispiel #13
0
    plt.title('dataset histogram')
    plt.xlabel('class_id')
    plt.ylabel('class_num')
    #图片抽样查看
    fig2 = plt.figure()
    images = dataset.data[:20]
    for i in np.arange(1, 21):
        plt.subplot(5, 4, i)
        plt.text(10, 10, '{}'.format(targets[i - 1]), fontsize=20, color='g')
        plt.imshow(images[i - 1])
    fig2.suptitle('Images')
    plt.show()


if __name__ == '__main__':
    net = ResNet18()
    writer = SummaryWriter(comment="myresnet")
    #绘制网络框图
    with writer:
        writer.add_graph(net, (torch.rand(1, 3, 32, 32), ))
    optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    lr_sch = torch.optim.lr_scheduler.StepLR(optimizer,
                                             30,
                                             gamma=0.1,
                                             last_epoch=-1)
    loss_func = nn.CrossEntropyLoss()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    # 随机获取训练图片
    for epoch in range(EPOCH):
Beispiel #14
0
import cv2

from test_transform import *
from resnet import ResNet18

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

WEIGHT_PATH = './weights/model_keypoints_68pts_iter_450.pt'  #change the weights path
# IMAGE_PATH = '../test_image' #change the image path

# 68 facial landmarks * 2 coordinates = 136 regression outputs.
net = ResNet18(136).to(device)
# map_location ensures GPU-saved weights still load when only a CPU is
# available (otherwise torch.load raises on a CUDA-less machine).
net.load_state_dict(torch.load(WEIGHT_PATH, map_location=device))
net.eval()


def detect(IMAGE_PATH):
    """Run the landmark network on the image at *IMAGE_PATH*.

    NOTE(review): this function is truncated in this chunk -- the code after
    the device transfer is not visible here.
    """
    origin_image = cv2.imread(IMAGE_PATH)

    # The network expects 224x224 input; keep a copy of the resized frame.
    image = cv2.resize(origin_image, (224, 224))
    origin_image = image.copy()
    image = Normalize(image)
    image = ToTensor(image)
    image = image.unsqueeze(0)  # add batch dimension: [1, C, H, W]

    with torch.no_grad():

        if (torch.cuda.is_available()):
            image = image.type(torch.cuda.FloatTensor)
            # NOTE(review): Tensor.to() is not in-place; this line has no
            # effect unless reassigned (image = image.to(device)).
            image.to(device)
Beispiel #15
0
    # Make sure file lists are complete
    chs = ("g", "r", "i")
    valid_set = [
        (fi, ch)
        for fi, ch in zip(np.repeat(valid_set, len(chs)), itertools.cycle(chs))
        if isfile("./data/gals/{}-{}.fits".format(fi, ch))
        and isfile("./data/sbs_gri_30.0/{}_{}.txt".format(fi, ch))
    ]
    training_set = [(fi, ch) for fi, ch in zip(
        np.repeat(training_set, len(chs)), itertools.cycle(chs))
                    if isfile("./data/gals/{}-{}.fits".format(fi, ch))
                    and isfile("./data/sbs_gri_30.0/{}_{}.txt".format(fi, ch))]

    # Init Pix2Prof and load checkpoint if asked
    encoder = ResNet18(num_classes=args.encoding_len).to(cuda)
    decoder = GRUNet(input_dim=1,
                     hidden_dim=args.encoding_len,
                     output_dim=1,
                     n_layers=3).to(cuda)
    criterion = nn.MSELoss()
    encoder_op = optim.Adam(encoder.parameters(), lr=0.0002)
    decoder_op = optim.Adam(decoder.parameters(), lr=0.0002)

    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint)
        encoder.load_state_dict(checkpoint["encoder"])
        decoder.load_state_dict(checkpoint["decoder"])
        decoder_op.load_state_dict(checkpoint["decoder_op"])
        encoder_op.load_state_dict(checkpoint["encoder_op"])
        chk_epoch = checkpoint["epoch"]
Beispiel #16
0
    # load data
    x_train = joblib.load('data/x_data.jl')
    y_train = joblib.load('data/y_data.jl')
    x_val = joblib.load('data/x_test.jl')
    y_val = joblib.load('data/y_test.jl')
    train_loader = get_loader(x_train, y_train, batch_size=args.batch_size, num_workers=args.num_workers,
                              transforms=train_transform, shuffle=True)
    val_loader = get_loader(x_val, y_val, batch_size=args.batch_size, num_workers=args.num_workers,
                            transforms=val_transform, shuffle=False)

    print(f'# Train Samples: {len(x_train)} | # Val Samples: {len(x_val)}')

    # define model
    if args.cnn_baseline:
        model = ResNet18().to(device)
    else:
        model = ViT(
            image_size=32,
            patch_size=args.patch_size,
            num_classes=10,
            dim=args.dim,
            depth=args.depth,
            heads=args.heads,
            mlp_dim_mul=args.mlp_dim_mul,
            dropout=args.dropout,
            emb_dropout=args.emb_dropout,
            rel_pos=args.rel_pos,
            rel_pos_mul=args.rel_pos_mul,
            n_out_convs=args.n_out_convs,
            squeeze_conv=args.squeeze_conv,
            adv_batch.requires_grad = True

            output = model(adv_batch)
            LL_output = output.min(1)[1]

            loss = loss_function(output, LL_output)

            grad = torch.autograd.grad(loss, adv_batch)[0]

            adv_batch.requires_grad = False
            adv_batch -= self.alpha * grad.sign()

            diff = adv_batch - batch
            diff = torch.clamp(diff, min=-self.epsilon, max=self.epsilon)
            adv_batch = batch + diff

            adv_batch = torch.clamp(adv_batch, min=0, max=255)
        return adv_batch


if __name__ == '__main__':
    # Smoke test: attack a randomly-initialised ResNet-18 with random data.
    model = ResNet18()
    loss_function = nn.CrossEntropyLoss()
    batch = torch.randn((8, 3, 28, 28))
    # NOTE(review): randint's `high` is exclusive, so labels span 0..8 --
    # confirm whether class 9 should be reachable (high=10).
    label = torch.randint(low=0, high=9, size=(8, ))

    # atk = FGSM(epsilon=4, device='cpu')
    # atk = IFGSM(alpha=4, epsilon=30, n_iteration=10, device='cpu')
    atk = LLFGSM(alpha=4, epsilon=30, n_iteration=10, device='cpu')

    # The attack returns an adversarial batch with the input's shape.
    print(atk(batch, label, model, loss_function).shape)
)

# One sample per batch, distributed via test_sampler; pinned memory speeds
# up host-to-GPU copies.
test_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=1,
                                          num_workers=16,
                                          pin_memory=True,
                                          sampler=test_sampler)

# ####### model for experiments ##############
# model=Network()
#
# model = my_resnt18(classes)
# model = Lenet5(classes)
# model = FashionCNN(classes)

model = ResNet18(classes)

# optimizer=torch.optim.SGD(model.parameters(), lr=0.03, weight_decay= 1e-6, momentum = 0.87,nesterov = True)
# optimizer=torch.optim.SGD(model.parameters(),lr=0.003,)

optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

criterion = torch.nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if pre_train:
    # Warm-start from saved weights before moving model/criterion to device.
    model.load_state_dict(torch.load(weights_path))

model.to(device)
criterion.to(device)