예제 #1
0
파일: train.py 프로젝트: w6688j/CBAM
def main(args):
    """Train an ImageNet-style classifier from ImageFolder data.

    Builds augmented ``ImageFolder`` loaders from ``<data_root>/t256`` and
    ``<data_root>/v256``, instantiates the architecture named by the prefix of
    ``args.model`` (the text before the first ``_``), and hands everything to
    ``Trainer`` for 90 epochs of SGD with a multi-step learning-rate schedule.

    Args:
        args: parsed command-line namespace. Fields read here: ``resume``,
            ``model``, ``display``, ``gpu`` (comma-separated device ids),
            ``data_root``, ``batch_size`` (multiplied by the number of GPUs)
            and ``debug``.

    Raises:
        ModuleNotFoundError: if the architecture prefix of ``args.model`` is
            not ``resnet50``, ``resnet50-cbam`` or ``resnet101``.
    """
    # Restrict the visible GPUs first: CUDA_VISIBLE_DEVICES has no effect once
    # the CUDA runtime has already been initialised in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    gpus = args.gpu.split(',')

    # A non-empty ``resume`` selects Logger(..., True) (presumably append mode
    # — confirm); ``not args.resume`` also tolerates ``None``.
    log_path = './logs/' + args.model + '.log'
    logger = Logger(log_path, True) if args.resume else Logger(log_path)
    logger.append(vars(args))

    writer = SummaryWriter() if args.display else None

    # Resolve the architecture before scanning the dataset directories so an
    # unknown model name fails immediately.
    arch = args.model.split('_')[0]
    model_builders = {
        'resnet50': lambda: models.resnet50(pretrained=False),
        'resnet50-cbam': lambda: resnet_cbam.resnet50_cbam(pretrained=False),
        'resnet101': lambda: models.resnet101(pretrained=False),
    }
    if arch not in model_builders:
        raise ModuleNotFoundError('unsupported model architecture: %r' % arch)

    # Standard ImageNet per-channel mean/std, shared by both pipelines.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        # NOTE(review): there is no Resize before CenterCrop; this assumes the
        # v256 images are already about 256 px on the short side — confirm.
        'val': transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    train_datasets = datasets.ImageFolder(os.path.join(args.data_root, 't256'),
                                          data_transforms['train'])
    val_datasets = datasets.ImageFolder(os.path.join(args.data_root, 'v256'),
                                        data_transforms['val'])
    # The per-GPU batch size is scaled by the number of requested devices.
    train_dataloaders = torch.utils.data.DataLoader(
        train_datasets,
        batch_size=args.batch_size * len(gpus),
        shuffle=True,
        num_workers=8)
    val_dataloaders = torch.utils.data.DataLoader(val_datasets,
                                                  batch_size=1024,
                                                  shuffle=False,
                                                  num_workers=8)

    if args.debug:
        # Log one raw training batch for inspection.
        x, y = next(iter(train_dataloaders))
        logger.append([x, y])

    is_use_cuda = torch.cuda.is_available()
    cudnn.benchmark = True

    my_model = model_builders[arch]()
    if is_use_cuda:
        my_model = my_model.cuda()
        if len(gpus) > 1:
            # Replicate the model across all visible GPUs.
            my_model = nn.DataParallel(my_model)

    loss_fn = [nn.CrossEntropyLoss()]
    optimizer = optim.SGD(my_model.parameters(),
                          lr=0.1,
                          momentum=0.9,
                          weight_decay=1e-4)
    # LR is multiplied by 0.1 at scheduler steps 30 and 60 (epochs, assuming
    # Trainer steps the scheduler once per epoch — confirm).
    lr_schedule = lr_scheduler.MultiStepLR(optimizer,
                                           milestones=[30, 60],
                                           gamma=0.1)

    # Presumably top-1 / top-5 meters; the meaning of the ``True`` flag is
    # defined by ClassErrorMeter — confirm.
    metric = [ClassErrorMeter([1, 5], True)]
    start_epoch = 0
    num_epochs = 90

    # NOTE(review): the meaning of 500 is defined by Trainer (possibly a
    # logging interval) — confirm.
    my_trainer = Trainer(my_model, args.model, loss_fn, optimizer, lr_schedule,
                         500, is_use_cuda, train_dataloaders, val_dataloaders,
                         metric, start_epoch, num_epochs, args.debug, logger,
                         writer)
    my_trainer.fit()
    logger.append('Optimize Done!')
예제 #2
0
def main(args):
    """Train a ResNet-50 / ResNet-50-CBAM / ResNet-101 classifier on CIFAR-10.

    Args:
        args: parsed command-line namespace. Fields read here: ``resume``,
            ``model``, ``gpu``, ``batch_size``, ``debug`` and ``epoch``.

    Raises:
        ModuleNotFoundError: if the architecture prefix of ``args.model`` (the
            text before the first ``_``) is not ``resnet50``,
            ``resnet50-cbam`` or ``resnet101``.
    """
    # region ---------------------------------- environment ----------------------------------
    # Restrict the visible GPUs first: CUDA_VISIBLE_DEVICES has no effect once
    # the CUDA runtime has already been initialised in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # endregion

    # region ---------------------------------- training log ----------------------------------
    # A non-empty ``resume`` selects Logger(..., True) (presumably append mode
    # — confirm); ``not args.resume`` also tolerates ``None``.
    log_path = './logs/' + args.model + '.log'
    logger = Logger(log_path, True) if args.resume else Logger(log_path)
    logger.append(vars(args))

    # A TensorBoard writer is always created; ``args.display`` is not consulted.
    writer = SummaryWriter()
    # endregion

    # region ---------------------------------- model choice ----------------------------------
    # Resolve the architecture before the CIFAR-10 download so an unknown
    # model name fails immediately.
    arch = args.model.split('_')[0]
    model_builders = {
        'resnet50': lambda: models.resnet50(pretrained=False),
        'resnet50-cbam': lambda: resnet_cbam.resnet50_cbam(pretrained=False),
        'resnet101': lambda: models.resnet101(pretrained=False),
    }
    if arch not in model_builders:
        raise ModuleNotFoundError('unsupported model architecture: %r' % arch)
    # endregion

    # region ---------------------------------- data loading ----------------------------------
    # CIFAR-10 images are only converted to tensors: no augmentation and no
    # normalisation are applied.
    train_datasets = datasets.CIFAR10(root="data",
                                      train=True,
                                      download=True,
                                      transform=transforms.ToTensor())
    val_datasets = datasets.CIFAR10(root="data",
                                    train=False,
                                    download=True,
                                    transform=transforms.ToTensor())

    train_dataloaders = torch.utils.data.DataLoader(train_datasets,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=4)
    val_dataloaders = torch.utils.data.DataLoader(val_datasets,
                                                  batch_size=1024,
                                                  shuffle=False,
                                                  num_workers=4)
    # endregion

    # region ---------------------------------- model setup -----------------------------------
    if args.debug:
        # Log one raw training batch for inspection.
        x, y = next(iter(train_dataloaders))
        logger.append([x, y])

    is_use_cuda = torch.cuda.is_available()
    cudnn.benchmark = True

    # NOTE(review): the torchvision/CBAM constructors are called with their
    # default class count, not 10 for CIFAR-10 — confirm this is intended.
    my_model = model_builders[arch]()
    if is_use_cuda:
        # is_use_cuda is handed to Trainer below, so the model must be on the
        # GPU too (presumably Trainer moves each batch to CUDA — confirm).
        # Previously the model was left on the CPU.
        my_model = my_model.cuda()
    # endregion

    # region ---------------------------------- training setup --------------------------------
    loss_fn = [nn.CrossEntropyLoss()]
    optimizer = optim.SGD(my_model.parameters(),
                          lr=0.1,
                          momentum=0.9,
                          weight_decay=1e-4)
    # LR is multiplied by 0.1 at scheduler steps 30 and 60 (epochs, assuming
    # Trainer steps the scheduler once per epoch — confirm).
    lr_schedule = lr_scheduler.MultiStepLR(optimizer,
                                           milestones=[30, 60],
                                           gamma=0.1)
    # Accumulated classification error, presumably top-1 / top-5.
    metric = [ClassErrorMeter([1, 5], True)]
    num_epochs = int(args.epoch)

    # Training starts at epoch 0. NOTE(review): the meaning of 500 is defined
    # by Trainer — confirm.
    my_trainer = Trainer(my_model, args.model, loss_fn, optimizer, lr_schedule,
                         500, is_use_cuda, train_dataloaders, val_dataloaders,
                         metric, 0, num_epochs, args.debug, logger, writer)
    my_trainer.fit()
    logger.append('训练完毕')
    # endregion
        self.power = power
        super(PolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the polynomially decayed learning rate of each param group.

        Computes ``base_lr * (1 - t / max_iter) ** power`` with
        ``t = self.last_epoch``. ``t`` is clamped to ``max_iter``, so stepping
        past the end keeps the learning rate at 0. Without the clamp the base
        goes negative, which gives a complex value for a fractional ``power``
        (e.g. 0.9) or a rising rate for an integer one.

        Raises:
            ZeroDivisionError: if ``self.max_iter`` is 0.
        """
        progress = min(self.last_epoch, self.max_iter) / self.max_iter
        factor = (1 - progress) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]


# Number of classes passed to the Evaluator below.
# NOTE(review): make_data_loader also returns ``num_class``; confirm the two agree.
NUM_CLASS = 19

# Use the first CUDA device when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'device:{device}')

args = parse_args()
# Options forwarded to make_data_loader (presumably on to DataLoader — confirm).
kwargs = dict(num_workers=0, pin_memory=True)
train_loader, val_loader, test_loader, num_class = make_data_loader(args, **kwargs)

# CBAM ResNet-50 built without pretrained weights, placed on the chosen device.
my_model = resnet_cbam.resnet50_cbam(pretrained=False).to(device)

# No per-class loss weighting; mode='ce' selects the loss variant (presumably
# cross-entropy — confirm in SegmentationLosses).
weight = None
criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='ce')
optimizer = torch.optim.SGD(
    my_model.parameters(),
    lr=args.lr,
    momentum=0,
    weight_decay=args.weight_decay,
)

# Polynomial decay with power 0.9 over ``args.epochs`` scheduler steps.
optimizer_lr_scheduler = PolyLR(optimizer, max_iter=args.epochs, power=0.9)

evaluator = Evaluator(NUM_CLASS)

def train(epoch, optimizer, train_loader):
    my_model.train()
    for iteration, batch in enumerate(train_loader):
        image, target = batch['image'], batch['label']
        inputs = image.to(device)