# Example 1
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point for distributed semantic-segmentation training.

    Initialises the process group for this GPU, builds datasets/loaders with
    distributed samplers, wraps the model in DistributedDataParallel, then
    runs the train/validate loop. Rank 0 (args.gpu == 0) is responsible for
    all printing, matplotlib plots and checkpoint files.

    Args:
        gpu: local GPU index for this process.
        ngpus_per_node: GPUs per node; used to derive the global rank.
        args: parsed CLI namespace (mutated: gpu, rank, start_epoch).
    """
    global best_pred
    args.gpu = gpu
    args.rank = args.rank * ngpus_per_node + gpu
    print('rank: {} / {}'.format(args.rank, args.world_size))
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True
    # data transforms (ImageNet mean/std normalisation)
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size
    }
    trainset = get_dataset(args.dataset,
                           split=args.train_split,
                           mode='train',
                           **data_kwargs)
    valset = get_dataset(args.dataset, split='val', mode='val', **data_kwargs)
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        valset, shuffle=False)
    # dataloader
    loader_kwargs = {
        'batch_size': args.batch_size,
        'num_workers': args.workers,
        'pin_memory': True
    }
    trainloader = data.DataLoader(trainset,
                                  sampler=train_sampler,
                                  drop_last=True,
                                  **loader_kwargs)
    valloader = data.DataLoader(valset, sampler=val_sampler, **loader_kwargs)
    nclass = trainset.num_class
    # model
    model_kwargs = {}
    if args.rectify:
        model_kwargs['rectified_conv'] = True
        model_kwargs['rectify_avg'] = args.rectify_avg
    model = get_segmentation_model(args.model,
                                   dataset=args.dataset,
                                   backbone=args.backbone,
                                   aux=args.aux,
                                   se_loss=args.se_loss,
                                   norm_layer=DistSyncBatchNorm,
                                   base_size=args.base_size,
                                   crop_size=args.crop_size,
                                   **model_kwargs)
    if args.gpu == 0:
        print(model)
    # optimizer: the pretrained backbone keeps the base LR, while the
    # freshly initialised head/aux layers train at 10x the base LR
    params_list = [
        {
            'params': model.pretrained.parameters(),
            'lr': args.lr
        },
    ]
    if hasattr(model, 'head'):
        params_list.append({
            'params': model.head.parameters(),
            'lr': args.lr * 10
        })
    if hasattr(model, 'auxlayer'):
        params_list.append({
            'params': model.auxlayer.parameters(),
            'lr': args.lr * 10
        })
    optimizer = torch.optim.SGD(params_list,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # criterions
    criterion = SegmentationLosses(se_loss=args.se_loss,
                                   aux=args.aux,
                                   nclass=nclass,
                                   se_weight=args.se_weight,
                                   aux_weight=args.aux_weight)
    # distributed data parallel
    model.cuda(args.gpu)
    criterion.cuda(args.gpu)
    model = DistributedDataParallel(model, device_ids=[args.gpu])
    metric = utils.SegmentationMetric(nclass=nclass)

    # resuming checkpoint
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        # map_location='cpu' prevents every process from deserialising the
        # checkpoint onto the GPU it was saved from (typically cuda:0),
        # which would spike memory on that single device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.module.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            optimizer.load_state_dict(checkpoint['optimizer'])
        best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    # clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0

    # lr scheduler
    scheduler = utils.LR_Scheduler_Head(args.lr_scheduler, args.lr,
                                        args.epochs, len(trainloader))
    train_losses = []  # mean training loss of each epoch completed in THIS run

    def training(epoch):
        """Run one training epoch; rank 0 also logs, plots and checkpoints."""
        # Required so DistributedSampler reshuffles differently every epoch.
        train_sampler.set_epoch(epoch)
        global best_pred
        train_loss = 0.0
        model.train()
        tic = time.time()
        for i, (image, target) in enumerate(trainloader):
            scheduler(optimizer, i, epoch, best_pred)
            optimizer.zero_grad()
            outputs = model(image)
            target = target.cuda(args.gpu)
            loss = criterion(*outputs, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            if i % 100 == 0 and args.gpu == 0:
                iter_per_sec = 100.0 / (
                    time.time() - tic) if i != 0 else 1.0 / (time.time() - tic)
                tic = time.time()
                print('Epoch: {}, Iter: {}, Speed: {:.3f} iter/sec, Train loss: {:.3f}'. \
                      format(epoch, i, iter_per_sec, train_loss / (i + 1)))
        train_losses.append(train_loss / len(trainloader))
        if args.gpu != 0:
            # Only rank 0 writes checkpoints and plot files; letting every
            # rank write the same paths concurrently corrupts them.
            return
        # Compare the last two recorded losses rather than indexing with
        # `epoch`: on a resumed run (start_epoch > 0) train_losses is
        # shorter than `epoch`, so train_losses[epoch] raised IndexError.
        # Store best_pred directly; the previous new_preds[(epoch-1)//10]
        # lookup could also index out of range (new_preds only grows once
        # per validation run).
        if len(train_losses) > 1 and train_losses[-1] < train_losses[-2]:
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_pred': best_pred,
                },
                args,
                False,
                filename='checkpoint_train.pth.tar')
        plt.plot(train_losses)
        plt.xlabel('Epoch')
        plt.ylabel('Train_loss')
        plt.title('Train_Loss')
        plt.grid()
        plt.savefig('./loss_fig/train_losses.pdf')
        plt.savefig('./loss_fig/train_losses.svg')
        plt.close()

    p_m = []        # (pixAcc, mIoU) per validation run (rank 0 only)
    new_preds = []  # (pixAcc + mIoU) / 2 per validation run (rank 0 only)

    def validation(epoch):
        """Fast single-crop validation; rank 0 plots metrics and checkpoints."""
        global best_pred
        is_best = False
        model.eval()
        metric.reset()

        for i, (image, target) in enumerate(valloader):
            with torch.no_grad():
                pred = model(image)[0]
                target = target.cuda(args.gpu)
                metric.update(target, pred)

            if i % 100 == 0:
                # torch_dist_sum is a collective: every rank must take part;
                # only rank 0 prints the reduced numbers.
                all_metircs = metric.get_all()
                all_metircs = utils.torch_dist_sum(args.gpu, *all_metircs)
                pixAcc, mIoU = utils.get_pixacc_miou(*all_metircs)
                if args.gpu == 0:
                    print('pixAcc: %.3f, mIoU1: %.3f' % (pixAcc, mIoU))

        all_metircs = metric.get_all()
        all_metircs = utils.torch_dist_sum(args.gpu, *all_metircs)
        pixAcc, mIoU = utils.get_pixacc_miou(*all_metircs)
        if args.gpu == 0:
            print('pixAcc: %.3f, mIoU2: %.3f' % (pixAcc, mIoU))

            p_m.append((pixAcc, mIoU))
            plt.plot(p_m)
            plt.xlabel('10 Epoch')
            plt.ylabel('pixAcc, mIoU')
            plt.title('pixAcc, mIoU')
            plt.grid()
            plt.legend(('pixAcc', 'mIoU'))

            plt.savefig('./loss_fig/pixAcc_mIoU.pdf')
            plt.savefig('./loss_fig/pixAcc_mIoU.svg')
            plt.close()

            if args.eval:
                # evaluation-only mode: report metrics, never checkpoint
                return
            new_pred = (pixAcc + mIoU) / 2
            new_preds.append(new_pred)

            plt.plot(new_preds)
            plt.xlabel('10 Epoch')
            plt.ylabel('new_predication')
            plt.title('new_predication')
            plt.grid()
            plt.savefig('./loss_fig/new_predication.pdf')
            plt.savefig('./loss_fig/new_predication.svg')
            plt.close()

            if new_pred > best_pred:
                is_best = True
                best_pred = new_pred
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_pred': best_pred,
                },
                args,
                is_best,
                filename='checkpoint_train_{}.pth.tar'.format(epoch + 1))

    if args.export:
        if args.gpu == 0:
            torch.save(model.module.state_dict(), args.export + '.pth')
        return

    if args.eval:
        validation(args.start_epoch)
        return

    if args.gpu == 0:
        print('Starting Epoch:', args.start_epoch)
        print('Total Epoches:', args.epochs)

    for epoch in range(args.start_epoch, args.epochs):
        tic = time.time()
        training(epoch)
        # Validate every 10 epochs and always on the final epoch. A trailing
        # validation() call after this loop was removed: it double-validated
        # the final epoch and raised NameError when start_epoch >= epochs.
        if epoch % 10 == 0 or epoch == args.epochs - 1:
            validation(epoch)
        elapsed = time.time() - tic
        if args.gpu == 0:
            print(f'Epoch: {epoch}, Time cost: {elapsed}')
# Example 2
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point for distributed semantic-segmentation training.

    Initialises the process group for this GPU, builds datasets/loaders with
    distributed samplers, wraps the model in DistributedDataParallel, then
    runs the train/validate loop. Rank 0 (args.gpu == 0) handles all
    printing and checkpoint files.

    Args:
        gpu: local GPU index for this process.
        ngpus_per_node: GPUs per node; used to derive the global rank.
        args: parsed CLI namespace (mutated: gpu, rank, start_epoch).
    """
    global best_pred
    args.gpu = gpu
    args.rank = args.rank * ngpus_per_node + gpu
    print('rank: {} / {}'.format(args.rank, args.world_size))
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True
    # data transforms (ImageNet mean/std normalisation)
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])
    # dataset
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size
    }
    trainset = get_dataset(args.dataset,
                           split=args.train_split,
                           mode='train',
                           **data_kwargs)
    valset = get_dataset(args.dataset, split='val', mode='val', **data_kwargs)
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        valset, shuffle=False)
    # dataloader
    loader_kwargs = {
        'batch_size': args.batch_size,
        'num_workers': args.workers,
        'pin_memory': True
    }
    trainloader = data.DataLoader(trainset,
                                  sampler=train_sampler,
                                  drop_last=True,
                                  **loader_kwargs)
    valloader = data.DataLoader(valset, sampler=val_sampler, **loader_kwargs)
    nclass = trainset.num_class
    # model
    model_kwargs = {}
    if args.rectify:
        model_kwargs['rectified_conv'] = True
        model_kwargs['rectify_avg'] = args.rectify_avg
    model = get_segmentation_model(args.model,
                                   dataset=args.dataset,
                                   backbone=args.backbone,
                                   aux=args.aux,
                                   se_loss=args.se_loss,
                                   norm_layer=DistSyncBatchNorm,
                                   base_size=args.base_size,
                                   crop_size=args.crop_size,
                                   **model_kwargs)
    if args.gpu == 0:
        print(model)
    # optimizer: the pretrained backbone keeps the base LR, while the
    # freshly initialised head/aux layers train at 10x the base LR
    params_list = [
        {
            'params': model.pretrained.parameters(),
            'lr': args.lr
        },
    ]
    if hasattr(model, 'head'):
        params_list.append({
            'params': model.head.parameters(),
            'lr': args.lr * 10
        })
    if hasattr(model, 'auxlayer'):
        params_list.append({
            'params': model.auxlayer.parameters(),
            'lr': args.lr * 10
        })
    optimizer = torch.optim.SGD(params_list,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # criterions
    criterion = SegmentationLosses(se_loss=args.se_loss,
                                   aux=args.aux,
                                   nclass=nclass,
                                   se_weight=args.se_weight,
                                   aux_weight=args.aux_weight)
    # distributed data parallel
    model.cuda(args.gpu)
    criterion.cuda(args.gpu)
    model = DistributedDataParallel(model, device_ids=[args.gpu])
    metric = utils.SegmentationMetric(nclass=nclass)

    # resuming checkpoint
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        # map_location='cpu' prevents every process from deserialising the
        # checkpoint onto the GPU it was saved from (typically cuda:0).
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.module.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            optimizer.load_state_dict(checkpoint['optimizer'])
        best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    # clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0

    # lr scheduler
    scheduler = utils.LR_Scheduler_Head(args.lr_scheduler, args.lr,
                                        args.epochs, len(trainloader))

    def training(epoch):
        """Run one training epoch; rank 0 logs progress every 100 iters."""
        # Without set_epoch, DistributedSampler produces the SAME shuffle
        # every epoch, degrading training.
        train_sampler.set_epoch(epoch)
        global best_pred
        train_loss = 0.0
        model.train()
        tic = time.time()
        for i, (image, target) in enumerate(trainloader):
            scheduler(optimizer, i, epoch, best_pred)
            optimizer.zero_grad()
            outputs = model(image)
            target = target.cuda(args.gpu)
            loss = criterion(*outputs, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if i % 100 == 0 and args.gpu == 0:
                iter_per_sec = 100.0 / (
                    time.time() - tic) if i != 0 else 1.0 / (time.time() - tic)
                tic = time.time()
                print('Epoch: {}, Iter: {}, Speed: {:.3f} iter/sec, Train loss: {:.3f}'. \
                      format(epoch, i, iter_per_sec, train_loss / (i + 1)))

    def validation(epoch):
        """Fast single-crop validation; rank 0 prints and checkpoints."""
        global best_pred
        is_best = False
        model.eval()
        metric.reset()

        for i, (image, target) in enumerate(valloader):
            with torch.no_grad():
                pred = model(image)[0]
                target = target.cuda(args.gpu)
                metric.update(target, pred)

            if i % 100 == 0 and args.gpu == 0:
                pixAcc, mIoU = metric.get()
                print('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        # torch_dist_avg is a distributed collective: EVERY rank must call
        # it. The original invoked it only under `if args.gpu == 0`, which
        # deadlocks the process group (non-zero ranks never enter the
        # collective). Compute the local metrics on all ranks, reduce on all
        # ranks, then let rank 0 alone print and save.
        pixAcc, mIoU = metric.get()
        pixAcc, mIoU = torch_dist_avg(args.gpu, pixAcc, mIoU)
        if args.gpu == 0:
            print('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

            new_pred = (pixAcc + mIoU) / 2
            if new_pred > best_pred:
                is_best = True
                best_pred = new_pred
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_pred': best_pred,
                }, args, is_best)

    if args.gpu == 0:
        print('Starting Epoch:', args.start_epoch)
        print('Total Epoches:', args.epochs)

    for epoch in range(args.start_epoch, args.epochs):
        tic = time.time()
        training(epoch)
        if epoch % 10 == 0:
            validation(epoch)
        elapsed = time.time() - tic
        if args.gpu == 0:
            print(f'Epoch: {epoch}, Time cost: {elapsed}')

    # Validate the final epoch unless the loop above already did. The
    # unconditional trailing call double-validated when (epochs-1) % 10 == 0
    # and raised NameError when start_epoch >= epochs left `epoch` unbound.
    last_epoch = args.epochs - 1
    if last_epoch >= args.start_epoch and last_epoch % 10 != 0:
        validation(last_epoch)