def main(args):
    """Dispatch to training or evaluation according to ``args.mode``.

    Args:
        args: parsed command-line namespace. ``mode`` selects the action;
            ``'train'`` additionally reads ``epochs``, ``evaluate_steps``,
            ``save_steps`` and ``save_dir``; ``'evaluate'`` reads
            ``model_path``.

    Relies on module-level ``model``, ``train_dataloader``, ``train``,
    ``evaluate`` and ``logger`` defined elsewhere in the file. Unknown modes
    and a missing ``model_path`` are reported through ``logger.error``.
    """
    # The original evaluated torch.device(...) and discarded the result; bind
    # it so the checkpoint below can be mapped onto whatever device exists.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mode = args.mode
    if mode == 'train':
        train(model, train_dataloader, args.epochs, args.evaluate_steps,
              args.save_steps, args.save_dir)
    elif mode == 'evaluate':
        model_path = args.model_path
        if model_path is None:
            logger.error('model_path is not set!')
            return
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine; load_state_dict copies into the model's own
        # parameters either way.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        loss, acc = evaluate(model)
        logger.info("loss=%f, accuracy=%f", loss, acc)
    else:
        logger.error('wrong mode: %s', mode)
Esempio n. 2
0
    transforms = Compose([Resize(resize_dim),ToTensor(),normalize])
convert_to = 'RGB'

# Build the dataset for the requested input format.
if args.input_type == 'dicom':
    dataset = iap.DicomSegment(args.img_path, transforms, convert_to)
elif args.input_type == 'png' and args.non_montgomery:
    dataset = iap.LungTest(args.img_path, transforms, convert_to)
elif args.input_type == 'png':
    # Montgomery layout: images under CXR_png plus separate left/right
    # manual lung masks.
    dataset = iap.lungSegmentDataset(
        os.path.join(args.img_path, "CXR_png"),
        os.path.join(args.img_path, "ManualMask/leftMask/"),
        os.path.join(args.img_path, "ManualMask/rightMask/"),
        imagetransform=transforms,
        labeltransform=Compose([Resize((224, 224)), ToTensor()]),
        convert_to=convert_to,
    )
else:
    # Without this branch an unknown input_type surfaced later as a
    # NameError on `dataset`.
    raise ValueError('unsupported input_type: {!r}'.format(args.input_type))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load(args.resume_from))
# Inference only: switch dropout / batch-norm layers (if any) to eval mode.
model.eval()
show = iap.visualize(dataset)

with torch.no_grad():
    for i, sample in enumerate(dataloader):
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
        img = sample['image'].cuda()
        mask = model(img)
        if not args.non_montgomery:
            show.ImageWithGround(i, True, True, save=True)

        show.ImageWithMask(i, sample['filename'][0],
                           mask.squeeze().cpu().numpy(), True, True, save=True)
Esempio n. 3
0
        pred = [sum(pred) / len(pred)] * len(pred)
    if not compute_metric:
        return pred
    pred = np.array(pred, dtype=np.float64)
    ce = log_loss(gt, pred)
    prauc = compute_prauc(pred, gt)
    rce = compute_rce(pred, gt)
    return ce, prauc, rce


if __name__ == '__main__':
    if not os.path.exists('results'):
        os.mkdir('results')
    checkpoint = torch.load(
        os.path.join(checkpoints_dir, model_name + '_best.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    if make_prediction:
        data = pd.read_csv(test_file,
                           sep='\x01',
                           header=None,
                           names=all_features,
                           encoding='utf-8')
        pred = test(model, compute_metric=False)
        print('prediction finished')
        pred = pd.concat([
            data[['tweet_id', 'engaging_user_id']],
            pd.DataFrame({'prediction': pred})
        ], 1)
        pred.to_csv(os.path.join('results', arg.label + '_prediction.csv'),
                    header=False,
Esempio n. 4
0
# Select the architecture together with the input size and colour mode it
# is configured for.
if args.model == 'resnet':
    model = model.segmentNetwork().cuda()
    resize_dim = (400, 400)
    convert_to = 'L'
elif args.model == 'unet11':
    model = unet11(out_filters=3).cuda()
    resize_dim = (224, 224)
    convert_to = 'RGB'
elif args.model == 'unet16':
    model = unet16(out_filters=3).cuda()
    resize_dim = (224, 224)
    convert_to = 'RGB'
else:
    # Previously an unknown name fell through with `model` never rebound and
    # `resize_dim` undefined, failing later with an unrelated error.
    raise ValueError('unknown model: {!r}'.format(args.model))

model = torch.nn.DataParallel(model)
# NOTE(review): the checkpoint is always the unet16 weights regardless of
# args.model — confirm this is intended.
model.load_state_dict(torch.load("unet16_100.pth"))

if args.no_normalize:
    transforms = Compose([Resize(resize_dim), ToTensor()])
else:
    transforms = Compose([Resize(resize_dim), ToTensor(), normalize])
# NOTE(review): this unconditionally overrides the per-architecture
# convert_to chosen above (e.g. 'L' for resnet); kept to preserve current
# behaviour — confirm intent.
convert_to = 'RGB'

# Root folder of the input images (alternative: CORDA-dataset-v3/).
source_path = '/home/enzo/data/Cohen/'
# mask_path = '/home/enzo/data/CORDA-dataset-v4-masks/'
# Output root, presumably for the generated masks
# (alternative: '/home/enzo/data/CORDA-dataset-v3-masked/').
target_path = '/home/enzo/data/Cohen-masks/'

# Sub-folders to process; previously
# ["RX--COVID-", "RX--COVID+", "RX+-COVID-", "RX+-COVID+"].
for this_folder in ["data"]:
    dataset = iap.LungTest(f"{source_path}{this_folder}", transforms,
                           convert_to)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
    )
MODELPATH = '/home/dxtien/dxtien_research/nmduy/pytorch-lung-segmentation/lung_segmentation/unet16_100.pth'

# Per-channel mean/std applied to the RGB input tensors.
normalize = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
model = unet16(out_filters=3).cuda()
resize_dim = (224, 224)
convert_to = 'RGB'  # the original assigned this twice; once is enough

transforms = Compose([Resize(resize_dim), ToTensor(), normalize])

dataset = iap.MyLungTest(IMGPATH, IMGPATHTXT, transforms, convert_to)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

# NOTE(review): the model is not wrapped in DataParallel here; if this
# checkpoint was saved from a DataParallel model its keys carry a 'module.'
# prefix and loading will fail — confirm.
model.load_state_dict(torch.load(MODELPATH))
# Inference only: switch dropout / batch-norm layers (if any) to eval mode.
model.eval()

with torch.no_grad():
    for sample in dataloader:
        img = sample['image'].cuda()
        mask = model(img)
        mask_np = mask.squeeze().cpu().numpy()
        # DataLoader's default collate turns a per-sample string into a list
        # of batch_size strings, so take element 0 (batch_size=1). Calling
        # .split on the list itself raised AttributeError.
        filename = sample['filename'][0]
        # Keep only the base name and drop its last four characters (a
        # three-letter extension such as '.png').
        filename = filename.split('/')[-1][:-4]
        save_mask(mask_np, OUTDIR, filename=filename + '_mask.png')
Esempio n. 6
0
import sys

# DataLoader options: one worker with pinned memory when CUDA is requested.
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

transform = transforms.Compose([
    transforms.ToTensor(),
])

cfg_params = utils.read_data_cfg(args.cfgfile)

train_txt = cfg_params['train']
test_txt = cfg_params['test']
backup_dir = cfg_params['backup']

if args.load_model is not None:
    print('Loading model from %s.' % args.load_model)
elif args.test:
    # Evaluating the test set requires trained weights. Exit with a non-zero
    # status: the bare builtin exit() returned 0 for this error and is only
    # defined when the `site` module has been loaded.
    print('Missing model file for evaluating test set.')
    sys.exit(1)

# Build the network once; restore weights only when a checkpoint was given.
model = models.model.UNet(args.im_size, args.kernel_size)
if args.load_model is not None:
    model.load_state_dict(torch.load(args.load_model))
# Datasets and dataloaders.
# Datasets and dataloaders are only built when not in test-only mode.
if not args.test:
    # Both splits are read from the same training list; `split` picks the
    # part and `val_samples` presumably sizes the held-out part.
    train_dataset = IGVCDataset(
        train_txt,
        split='train',
        im_size=args.im_size,
        transform=transform,
        val_samples=args.val_samples,
    )
    val_dataset = IGVCDataset(
        train_txt,
        split='val',
        im_size=args.im_size,
        transform=transform,
        val_samples=args.val_samples,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    # Optimizer settings.
    lr = args.lr
def train(model, criterion, converter, device, pretrain=False):
    imgdir_list = []
    label_list = []

    rootdir = '/home/chen-ubuntu/Desktop/checks_dataset/new_train_stamp_crop/'

    for mode in ['new_crop', 'new_crop2', 'new_crop3']:
        imgfile_dir = os.path.join(rootdir, mode)
        imgs_files = sorted(os.listdir(imgfile_dir))

        for img_file in imgs_files:
            imgs_dir = os.path.join(imgfile_dir, img_file)
            imgs = sorted(os.listdir(imgs_dir))
            for img in imgs:
                img_dir = os.path.join(imgs_dir, img)
                imgdir_list.append(img_dir)
                label_list.append(img[:-6])

    if len(imgdir_list) != len(label_list):
        print('dataset is wrong!')

    np.random.seed(10)
    state = np.random.get_state()
    np.random.shuffle(imgdir_list)
    np.random.set_state(state)
    np.random.shuffle(label_list)

    segment = len(imgdir_list) // 10
    train_imgdirs = imgdir_list[:segment * 9]
    train_labels = label_list[:segment * 9]
    val_imgdirs = imgdir_list[segment * 9:]
    val_labels = label_list[segment * 9:]

    print('trainset: ', len(train_imgdirs))
    print('validset: ', len(val_labels))

    trainset = dataset_seal.BaseDataset(train_imgdirs,
                                        train_labels,
                                        transform=dataset_seal.img_enhancer,
                                        _type='seal')
    validset = dataset_seal.BaseDataset(val_imgdirs,
                                        val_labels,
                                        transform=dataset_seal.img_padder,
                                        _type='seal')

    print('Device:', device)
    model = model.to(device)

    if pretrain:
        print("Using pretrained model")
        '''
        state_dict = torch.load("/home/chen-ubuntu/Desktop/checks_dataset/pths/crnn_pertrain.pth", map_location=device)

        cnn_modules = {}
        rnn_modules = {}
        for module in state_dict:
            if module.split('.')[1] == 'FeatureExtraction':
                key = module.replace("module.FeatureExtraction.", "")
                cnn_modules[key] = state_dict[module]
            elif module.split('.')[1] == 'SequenceModeling':
                key = module.replace("module.SequenceModeling.", "")
                rnn_modules[key] = state_dict[module]

        model.cnn.load_state_dict(cnn_modules)
        model.rnn.load_state_dict(rnn_modules)
        '''
    model.load_state_dict(
        torch.load(
            '/home/chen-ubuntu/Desktop/checks_dataset/pths/seal_lr3_bat256_aug_epoch15_acc0.704862.pth'
        ))

    dataloader = DataLoader(trainset,
                            batch_size=64,
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)
    '''
    lr = 1e-3
    params = model.parameters()
    optimizer = optim.Adam(params, lr)
    optimizer.zero_grad()
    batch_cnt = 0
    for epoch in range(config.epochs):
        epoch_loss = 0
        model.train()
        train_acc = 0
        train_acc_cnt = 0
        
        for i, (img, label, _) in enumerate(dataloader):
            n_correct = 0
            batch_cnt += 1
            train_acc_cnt += 1
            img = img.to(device)
            text, length = converter.encode(label)
            preds = model(img)
            preds_size = torch.IntTensor([preds.size(0)] * img.size(0))
            preds = preds.to('cpu')
            loss = criterion(preds, text, preds_size, length)

            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)

            list1 = [x for x in label]
            for pred, target in zip(sim_preds, list1):
                if pred == target:
                    n_correct += 1

            # loss.backward()
            # optimizer.step()
            # model.zero_grad()

            loss.backward()
            if (i + 1) % 4:
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss += loss.item()
            train_acc += n_correct / len(list1)

            if (i + 1) % 4 == 0:
                print("epoch: {:<3d}, batch: {:<3d},  batch loss: {:4f}, epoch loss: {:4f}, acc: {}". \
                      format(epoch, i, loss.item(), epoch_loss, n_correct / len(list1)))
                writer.add_scalar('data/train_loss', loss.item(), batch_cnt)
                writer.add_scalar('data/train_acc', n_correct / len(list1), batch_cnt)
    '''
    #print('train_average_acc is: {:.3f}'.format(train_acc / train_acc_cnt))
    acc, valid_loss = valid(model, criterion, converter, device, validset)
    '''
Esempio n. 8
0
def train(model,
          criterion,
          converter,
          device,
          train_datasets,
          pretrain=False):  #valid_datasets=None, pretrain=False):
    """Train ``model`` on the ``'seal'`` entry of ``train_datasets``.

    Args:
        model: network to train; it is moved to ``device``.
        criterion: loss called as ``criterion(preds, text, preds_size, length)``
            on CPU tensors.
        converter: label codec; ``encode(labels)`` returns ``(text, length)``
            and ``decode(preds, preds_size, raw=False)`` returns predictions
            that are compared with the labels for accuracy.
        device: device the model and input images are moved to.
        train_datasets: mapping from dataset name to dataset.
        pretrain: only prints a message; partial CNN/RNN weight loading is
            disabled.

    Relies on module-level ``config`` (``config.epochs``), ``writer``,
    ``optim`` and ``DataLoader``. A checkpoint is written every third epoch
    and whenever the epoch's mean training accuracy exceeds 0.8.
    """
    print('Device:', device)
    model = model.to(device)

    if pretrain:
        # Partial loading of CNN/RNN weights from crnn_pertrain.pth is
        # currently disabled.
        print("Using pretrained model")

    # NOTE(review): this checkpoint is loaded unconditionally, independent of
    # `pretrain`; kept as-is — confirm intent.
    model.load_state_dict(
        torch.load(
            '/home/chen-ubuntu/Desktop/checks_dataset/pths/seal_lr3_bat256_aug_epoch15_acc0.704862.pth',
            map_location=device))

    dataset_name = 'seal'
    batch_dict = {
        'print_word': 32,
        'hand_num': 48,
        'print_num': 48,
        'symbol': 64,
        'hand_word': 64,
        'seal': 64
    }
    dataset = train_datasets.get(dataset_name)
    dataloader = DataLoader(dataset,
                            batch_size=batch_dict.get(dataset_name),
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)

    # Gradients of this many consecutive batches are summed per optimizer step.
    accum_steps = 4
    lr = 1e-3
    optimizer = optim.Adam(model.parameters(), lr)
    optimizer.zero_grad()
    batch_cnt = 0
    for epoch in range(config.epochs):
        epoch_loss = 0
        model.train()
        train_acc = 0
        train_acc_cnt = 0

        for i, (img, label, _) in enumerate(dataloader):
            batch_cnt += 1
            train_acc_cnt += 1
            img = img.to(device)
            text, length = converter.encode(label)
            preds = model(img)
            preds_size = torch.IntTensor([preds.size(0)] * img.size(0))
            preds = preds.to('cpu')
            loss = criterion(preds, text, preds_size, length)

            # Greedy decoding: best class per step, then flatten batch-major.
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data,
                                         preds_size.data,
                                         raw=False)

            targets = list(label)
            n_correct = sum(pred == target
                            for pred, target in zip(sim_preds, targets))
            batch_acc = n_correct / len(targets)

            # Gradient accumulation: step once every `accum_steps` batches.
            # The original test `(i + 1) % 4` was inverted and stepped on
            # three batches out of every four.
            loss.backward()
            if (i + 1) % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss += loss.item()
            train_acc += batch_acc

            if (i + 1) % accum_steps == 0:
                print("epoch: {:<3d}, dataset:{:<8}, batch: {:<3d},  batch loss: {:4f}, epoch loss: {:4f}, acc: {}".format(
                    epoch, dataset_name, i, loss.item(), epoch_loss, batch_acc))
                writer.add_scalar('data/train_loss', loss.item(), batch_cnt)
                writer.add_scalar('data/train_acc', batch_acc, batch_cnt)

        # Apply gradients left over from a final partial accumulation window
        # so they do not leak into the next epoch.
        if train_acc_cnt % accum_steps:
            optimizer.step()
            optimizer.zero_grad()

        if train_acc_cnt == 0:
            # Empty dataloader: nothing to average, log or save this epoch.
            continue
        avg_acc = train_acc / train_acc_cnt
        print('train_average_acc is: {:.3f}'.format(avg_acc))
        # This is training accuracy; the original logged it under a 'valid'
        # tag. (Validation on valid_datasets is currently disabled.)
        writer.add_scalar('data/train_{}avg_acc'.format(dataset_name),
                          avg_acc, batch_cnt)

        if epoch % 3 == 0:
            torch.save(
                model.state_dict(),
                '/home/chen-ubuntu/Desktop/checks_dataset/tmp_pths/allseal_lr3_bat512_expaug_epoch_{}_acc{:4f}.pth'
                .format(epoch + 1, avg_acc))

        if avg_acc > 0.8:
            torch.save(
                model.state_dict(),
                '/home/chen-ubuntu/Desktop/checks_dataset/tmp_pths/allseal_lr3_bat512_expaug_epoch{}_acc{:4f}.pth'
                .format(epoch + 1, avg_acc))
Esempio n. 9
0
    ],
                         key=lambda target: target.skeleton_type)
    hf = h5py.File(os.path.join(cfg.INFERENCE.H5.PATH, 'source.h5'), 'w')
    g1 = hf.create_group('group1')
    source_pos = torch.stack([data.pos for data in test_data], dim=0)
    g1.create_dataset('l_joint_pos_2', data=source_pos[:, :3])
    g1.create_dataset('r_joint_pos_2', data=source_pos[:, 3:])
    hf.close()
    print('Source H5 file saved!')

    # Create model
    model = getattr(model, cfg.MODEL.NAME)().to(device)

    # Load checkpoint
    if cfg.MODEL.CHECKPOINT is not None:
        model.load_state_dict(torch.load(cfg.MODEL.CHECKPOINT))

    # store initial z
    model.eval()
    z_all = []
    for batch_idx, data_list in enumerate(test_loader):
        for target_idx, target in enumerate(test_target):
            # fetch target
            target_list = [target for data in data_list]
            # forward
            z = model.encode(
                Batch.from_data_list(data_list).to(device)).detach()
            #             z = torch.empty(Batch.from_data_list(target_list).x.size(0), 64).normal_(mean=0, std=0.005).to(device)
            z.requires_grad = True
            z_all.append(z)
Esempio n. 10
0
    print("Using device {}".format(device))

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    cfg_params = utils.read_data_cfg(args.cfgfile)

    train_txt = cfg_params["train"]
    test_txt = cfg_params["test"]
    backup_dir = cfg_params["backup"]

    if args.load_model is not None:
        print("Loading model from %s." % args.load_model)
        model = models.model.UNet(args.im_size, args.kernel_size)
        model.load_state_dict(torch.load(args.load_model))
    elif args.test:
        print("Missing model file for evaluating test set.")
        exit()
    else:
        model = models.model.UNet(args.im_size, args.kernel_size)

    # datasets and dataloaders.
    if not args.test:
        train_dataset = IGVCDataset(
            train_txt,
            im_size=args.im_size,
            split="train",
            transform=transform,
            val_samples=args.val_samples,
        )
Esempio n. 11
0
        # image_copy.show()
        # image_copy.save(os.path.join(images_path, f"faster_rcnn/{attempt}/images/{name}.png"))
        print(f"{name}, time: {elapsed_time}")
        plt.imshow(image_copy)
        plt.show()
        break


if __name__ == "__main__":
    # torch.manual_seed(4)
    # NOTE(review): `model` is bound to an imported module here, yet
    # `model.load_state_dict` below needs an nn.Module instance; a
    # constructor call appears to be missing — confirm.
    from models import model

    attempt = 7  # run identifier
    model_name = "faster_rcnn_7_30.pt"

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Running on {device}...")

    # Use model_name rather than repeating the file name literal, so the two
    # cannot drift apart.
    model.load_state_dict(
        torch.load(os.path.join(models_path, model_name),
                   map_location=device))
    dataset = MyTestDataset(split='stage1_test',
                            transforms=get_test_transforms(rescale_size=(256,
                                                                         256)))
    test_loader = DataLoader(dataset,
                             batch_size=1,
                             num_workers=0,
                             shuffle=True)

    predict(model, test_loader)
Esempio n. 12
0
def train(epoch, model, criterion_cont, criterion_trip, criterion_sim,
          criterion_l2, criterion_label, optimizer, scheduler, trainLoader,
          device, cont_iter):
    """Run one training epoch and return the updated iteration counter.

    Each batch combines:
      * ``criterion_trip`` on ``ident_features`` (returns a triplet and a
        similarity loss),
      * ``criterion_label`` on ``output_bag`` against ``bags_label``,
      * two ``criterion_l2`` reconstruction terms: ``output_all`` against the
        input images, and ``output_ident`` against an image of the same pid
        whose ``bags_label`` is 0.

    ``criterion_cont`` and ``criterion_sim`` are accepted but unused. Relies
    on module-level ``args``, ``writer``, ``train_f``, ``AverageMeter`` and
    ``tqdm``.

    Returns:
        ``cont_iter`` advanced by the number of batches processed.
    """
    model.train(True)
    losses = AverageMeter()
    cont_losses = AverageMeter()
    trip_losses = AverageMeter()
    sim_losses = AverageMeter()
    label_losses = AverageMeter()
    # NOTE(review): stepping the scheduler before any optimizer.step() is the
    # pre-1.1 PyTorch ordering (newer versions warn); kept so the learning
    # rate schedule is unchanged.
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    print('lr:', lr)
    if lr <= 0.0001:
        # Once the learning rate has decayed this far, continue from the best
        # checkpoint saved so far.
        checkpoint = torch.load('./' + args.save_dir + '/best_model.pth.tar')
        model.load_state_dict(checkpoint['state_dict'])
    total = 0
    total_correct = 0
    for imgs, pids, bags_label in tqdm(trainLoader):
        # Put everything on `device`. The original moved inputs with .cuda()
        # but built the reconstruction target on `device`.
        imgs = imgs.to(device=device, dtype=torch.float)
        pids = pids.to(device=device, dtype=torch.long)
        bags_label = bags_label.to(device=device, dtype=torch.long)

        optimizer.zero_grad()
        output_ident, output_all, ident_features, output_bag = model(imgs)
        _, predicted = torch.max(output_bag, 1)
        total += imgs.shape[0]
        total_correct += (predicted == bags_label).sum().item()

        # Reconstruction target for output_ident: for every sample, an image
        # of the same pid with bags_label == 0 (the last such one in the
        # batch). Converting to Python lists once avoids a device sync per
        # element.
        pid_list = pids.tolist()
        no_bag_index = {
            pid: idx
            for idx, (pid, bag) in enumerate(zip(pid_list, bags_label.tolist()))
            if bag == 0
        }
        # NOTE(review): raises KeyError when a pid has no bags_label == 0
        # image in the batch (as before); presumably the sampler guarantees
        # one — confirm.
        target_idx = torch.tensor([no_bag_index[pid] for pid in pid_list],
                                  dtype=torch.long, device=imgs.device)
        imgs_proj_without_bag = imgs[target_idx]

        label_loss = criterion_label(output_bag, bags_label)
        trip_loss, sim_loss = criterion_trip(ident_features, pids)
        cont_loss_withoutBag = criterion_l2(output_ident,
                                            imgs_proj_without_bag)
        cont_loss_withBag = criterion_l2(output_all, imgs)
        cont_loss = cont_loss_withBag + cont_loss_withoutBag
        # The original branched on `epoch < 10` with identical weights in
        # both arms; collapsed to a single expression.
        loss = trip_loss + sim_loss * 0.1 + cont_loss * 500 + label_loss * 0.05

        loss.backward()
        optimizer.step()

        # Running averages, also logged to tensorboardX.
        losses.update(loss.item())
        trip_losses.update(trip_loss.item())
        sim_losses.update(sim_loss.item())
        cont_losses.update(cont_loss.item())
        label_losses.update(label_loss.item())
        writer.add_scalar("Train/Loss", losses.val, cont_iter)
        writer.add_scalar("Train/trip_Loss", trip_losses.val, cont_iter)
        writer.add_scalar("Train/sim_Loss", sim_losses.val, cont_iter)
        writer.add_scalar("Train/cont_Loss", cont_losses.val, cont_iter)
        writer.add_scalar("Train/label_loss", label_losses.val, cont_iter)
        cont_iter += 1

        if (cont_iter + 1) % 50 == 0:
            line = _progress_line(cont_iter, losses, trip_losses, sim_losses,
                                  cont_losses, label_losses,
                                  total_correct / total)
            print(line)
            train_f.write(line)
            train_f.write('\n')
    return cont_iter


def _progress_line(cont_iter, losses, trip_losses, sim_losses, cont_losses,
                   label_losses, correct_rate):
    """Format one progress line from the running loss meters.

    Shared by the console and the log file so the two cannot drift. The
    original format lacked separators before 'label_loss' and
    'total_correct_rate'.
    """
    return ("iter {}\t Loss {:.4f} ({:.4f}) "
            "trip_loss {:.4f} ({:.4f}) "
            "sim_loss {:.4f} ({:.4f}) "
            "cont_loss {:.4f} ({:.4f}) "
            "label_loss {:.4f} ({:.4f}) "
            "total_correct_rate ({:.5f})").format(
                cont_iter, losses.val, losses.avg, trip_losses.val,
                trip_losses.avg, sim_losses.val, sim_losses.avg,
                cont_losses.val, cont_losses.avg, label_losses.val,
                label_losses.avg, correct_rate)
Esempio n. 13
0
    root='G:/LJH/DATASETS/flower_photos', transform=train_transform)

# Split by index and wrap the parts in lazy Subset views. Passing the dataset
# itself to sklearn's train_test_split indexes every sample up front: it loads
# all images into memory and applies train_transform only once, which freezes
# any random augmentation. Splitting range(len(dataset)) with the same
# test_size/random_state selects exactly the same samples in the same order.
train_indices, valid_indices = train_test_split(list(range(len(dataset))),
                                                test_size=0.2,
                                                random_state=0)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
valid_dataset = torch.utils.data.Subset(dataset, valid_indices)
print(len(train_dataset))
print(len(valid_dataset))
# Batch size: the number of samples used per training step; it affects both
# how well and how quickly the model optimizes.
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=0)
valid_loader = DataLoader(valid_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=0)

model = model.MobileNetV3_large()
# NOTE(review): 'weigths' looks like a typo, but it must match the file name
# on disk; left unchanged.
model.load_state_dict(torch.load('weigths/best.pkl'))
# Replace the last layer so the classifier outputs 5 classes.
model.conv4 = Conv2d(1280, 5, kernel_size=(1, 1), stride=(1, 1))
model.to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
loss_func = torch.nn.CrossEntropyLoss()
avg_loss = []
avg_acc = []


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)