Example #1
def train_fcn(train_img_path, pths_path, batch_size, lr, num_workers,
              epoch_iter, interval):
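    """Train an FCN8s segmentation model (VGGNet feature extractor, 2 classes).

    Loads the training set from train_img_path, optimizes with Adam plus a
    MultiStepLR schedule, and writes a checkpoint to pths_path every
    `interval` epochs.
    """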
    #file_num = 1056  # tentative
    if (not os.path.exists(pths_path)):
        os.makedirs(pths_path)

    trainset = fcn_dataset(train_img_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers, drop_last=False)
    #criterion = cross_entropy2d()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #device = torch.device("cpu")
    vgg_model = VGGNet(pretrained=False, requires_grad=True, remove_fc=True)
    fcn_model = FCN8s(pretrained_net=vgg_model, n_class=2)
    vgg_model.to(device)
    fcn_model.to(device)

    data_parallel = False
    #model.to(device)
    optimizer = torch.optim.Adam(fcn_model.parameters(), lr=lr)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[epoch_iter // 2],
                                         gamma=0.1)
    #criterion = BinaryDiceLoss()

    for epoch in range(epoch_iter):
        fcn_model.train()

        epoch_loss = 0
        epoch_time = time.time()
        for i, (img, mask) in enumerate(train_loader):
            start_time = time.time()
            img, mask = img.to(device), mask.to(device)

            output = fcn_model(img)
            #loss = nn.BCEWithLogitsLoss(output, mask)
            loss = cross_entropy2d(output, mask)
            #loss = get_dice_loss(output, mask)
            #loss = criterion(output, mask)
            #loss /= len(img)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('Epoch is [{}/{}], time consumption is {:.8f}, batch loss is {:.8f}'.format(
                epoch+1, epoch_iter, time.time()-start_time, loss.item()))

        # Step the epoch-based MultiStepLR schedule once per epoch, not per batch.
        scheduler.step()

        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / len(train_loader),
            time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)

        if (epoch + 1) % interval == 0:
            state_dict = (fcn_model.module.state_dict()
                          if data_parallel else fcn_model.state_dict())
            torch.save(
                state_dict,
                os.path.join(pths_path,
                             'model_epoch_{}.pth'.format(epoch + 1)))
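
# A minimal usage sketch; the paths and hyperparameters below are illustrative
# placeholders, not values taken from this script:
# train_fcn(train_img_path='./data/train_img', pths_path='./pths',
#           batch_size=8, lr=1e-3, num_workers=4, epoch_iter=100, interval=5)
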
def main():
    # args = parse_args()

    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
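    # cudnn.benchmark autotunes convolution algorithms for fixed-size inputs;
    # CUDA_VISIBLE_DEVICES restricts this process to GPUs 0 and 1.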

    # # if args.seed:
    # random.seed(args.seed)
    # np.random.seed(args.seed)
    # torch.manual_seed(args.seed)
    # # if args.gpu:
    # torch.cuda.manual_seed_all(args.seed)
    seed = 63
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # if args.gpu:
    torch.cuda.manual_seed_all(seed)

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # train_transforms = transforms.Compose([
    # 	transforms.RandomCrop(args['crop_size']),
    # 	transforms.RandomRotation(90),
    # 	transforms.RandomHorizontalFlip(p=0.5),
    # 	transforms.RandomVerticalFlip(p=0.5),

    # 	])
    short_size = int(min(args['input_size']) / 0.875)
    # val_transforms = transforms.Compose([
    # 	transforms.Scale(short_size, interpolation=Image.NEAREST),
    # 	# joint_transforms.Scale(short_size),
    # 	transforms.CenterCrop(args['input_size'])
    # 	])
    train_joint_transform = joint_transforms.Compose([
        # joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['crop_size']),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomRotate(90)
    ])
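    # The joint transforms apply the same random crop/flip/rotation to the image
    # and its label mask so the two stay pixel-aligned.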
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
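    # input_transform normalizes images with the ImageNet mean/std; the
    # restore_transform below undoes that normalization and converts back to a
    # PIL image for visualization.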
    input_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(*mean_std)])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose(
        [extended_transforms.DeNormalize(*mean_std),
         transforms.ToPILImage()])
    visualize = transforms.ToTensor()

    train_set = cityscapes.CityScapes('train',
                                      joint_transform=train_joint_transform,
                                      transform=input_transform,
                                      target_transform=target_transform)
    # train_set = cityscapes.CityScapes('train', transform=train_transforms)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = cityscapes.CityScapes('val',
                                    joint_transform=val_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    # val_set = cityscapes.CityScapes('val', transform=val_transforms)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=8,
                            shuffle=True)

    print(len(train_loader), len(val_loader))


    vgg_model = VGGNet(requires_grad=True, remove_fc=True)
    net = FCN8s(pretrained_net=vgg_model,
                n_class=cityscapes.num_classes,
                dropout_rate=0.4)
    # net.apply(init_weights)
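    # ignore_index excludes pixels labeled cityscapes.ignore_label from the loss.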
    criterion = nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)

    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(
            os.path.join(ckpt_path, exp_name,
                         str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10)
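    # ReduceLROnPlateau lowers the learning rate when the validation loss stops
    # improving; it is stepped with scheduler.step(val_loss) after each epoch's
    # validation below.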

    vgg_model = vgg_model.to(device)
    net = net.to(device)

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9][:-4])
        }
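        # The indices above presume an underscore-separated snapshot name such as
        # 'epoch_<E>_loss_<L>_acc_<A>_acc-cls_<C>_mean-iu_<M>.pth' (illustrative format).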

    criterion.to(device)

    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, device, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, device, criterion, optimizer,
                            epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
def main():

    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    vgg_model = VGGNet(requires_grad=True, remove_fc=True)
    net = FCN8s(pretrained_net=vgg_model,
                n_class=cityscapes.num_classes,
                dropout_rate=0.4)
    print('load model ' + args['snapshot'])

    vgg_model = vgg_model.to(device)
    net = net.to(device)

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, args['exp_name'],
                                args['snapshot'])))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    short_size = int(min(args['input_size']) / 0.875)
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(*mean_std)])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose(
        [extended_transforms.DeNormalize(*mean_std),
         transforms.ToPILImage()])

    # test_set = cityscapes.CityScapes('test', transform=test_transform)

    test_set = cityscapes.CityScapes('test',
                                     joint_transform=val_joint_transform,
                                     transform=test_transform,
                                     target_transform=target_transform)

    test_loader = DataLoader(test_set,
                             batch_size=1,
                             num_workers=8,
                             shuffle=False)
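    # batch_size=1 lets the loop below drop the batch dimension and save every
    # prediction as an individual image.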

    transform = transforms.ToPILImage()

    check_mkdir(os.path.join(ckpt_path, args['exp_name'], 'test'))

    gts_all, predictions_all = [], []
    count = 0
    for vi, data in enumerate(test_loader):
        # img_name, img = data
        img_name, img, gts = data

        img_name = img_name[0]
        # print(img_name)
        img_name = img_name.split('/')[-1]
        # img.save(os.path.join(ckpt_path, args['exp_name'], 'test', img_name))

        img_transform = restore_transform(img[0])
        # img_transform = img_transform.convert('RGB')
        img_transform.save(
            os.path.join(ckpt_path, args['exp_name'], 'test', img_name))
        img_name = img_name.split('_leftImg8bit.png')[0]

        # img = Variable(img, volatile=True).cuda()
        img, gts = img.to(device), gts.to(device)
        # Inference only: run the forward pass with autograd disabled.
        with torch.no_grad():
            output = net(img)

        prediction = output.data.max(1)[1].squeeze_(1).squeeze_(
            0).cpu().numpy()
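        # Per-pixel argmax over the class channel gives an (H, W) map of class ids.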
        prediction_img = cityscapes.colorize_mask(prediction)
        # print(type(prediction_img))
        prediction_img.save(
            os.path.join(ckpt_path, args['exp_name'], 'test',
                         img_name + '.png'))
        # print(ckpt_path, args['exp_name'], 'test', img_name + '.png')

        print('%d / %d' % (vi + 1, len(test_loader)))
        gts_all.append(gts.data.cpu().numpy())
        predictions_all.append(prediction)
        # break

        # if count == 1:
        #     break
        # count += 1
    gts_all = np.concatenate(gts_all)
    # Stack the (H, W) per-image predictions into an (N, H, W) array to match gts_all.
    predictions_all = np.stack(predictions_all)
    acc, acc_cls, mean_iu, _ = evaluate(predictions_all, gts_all,
                                        cityscapes.num_classes)

    print(
        '-----------------------------------------------------------------------------------------------------------'
    )
    print('[acc %.5f], [acc_cls %.5f], [mean_iu %.5f]' %
          (acc, acc_cls, mean_iu))
def main():
    args = parse_args()

    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    # Random seed for reproducibility
    if args.seed:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.gpu:
            torch.cuda.manual_seed_all(args.seed)

    # seed = 63
    # random.seed(seed)
    # np.random.seed(seed)
    # torch.manual_seed(seed)
    # # if args.gpu:
    # torch.cuda.manual_seed_all(seed)

    denormalize_argument = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_transforms = transforms.Compose([
        transforms.RandomCrop((args.crop_size, args.crop_size)),
        transforms.RandomRotation(90),
        transforms.RandomHorizontalFlip(p=0.5),
    ])
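    # Note: these torchvision transforms act on the image only; augmentations that
    # must also move the label mask would need joint transforms (see the
    # commented-out block below).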

    # train_joint_transform = joint_transforms.Compose([
    # 	# joint_transforms.Scale(img_resize_shape),
    # 	joint_transforms.RandomCrop(args['crop_size']),
    # 	joint_transforms.RandomHorizontallyFlip(),
    # 	joint_transforms.RandomRotate(90)
    # ])

    img_resize_shape = int(min(args.input_size) / 0.8)
    # val_transforms = transforms.Compose([
    # 	transforms.Scale(img_resize_shape, interpolation=Image.NEAREST),
    # 	transforms.CenterCrop(args['input_size'])
    # 	])

    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(img_resize_shape),
        joint_transforms.CenterCrop(args.input_size)
    ])
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose([
        extended_transforms.DeNormalize(*denormalize_argument),
        transforms.ToPILImage()
    ])
    visualize = transforms.ToTensor()

    # train_set = games_data.CityScapes('train', joint_transform=train_joint_transform,
    # 								  transform=input_transform, target_transform=target_transform)
    train_set = games_data.CityScapes('train', transform=train_transforms)
    # train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True)
    train_loader = DataLoader(train_set,
                              batch_size=args.training_batch_size,
                              num_workers=8,
                              shuffle=True)
    val_set = games_data.CityScapes('val',
                                    joint_transform=val_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args.val_batch_size,
                            num_workers=8,
                            shuffle=True)

    print(len(train_loader), len(val_loader))

    # Load pretrained VGG model
    vgg_model = VGGNet(requires_grad=True, remove_fc=True)

    # FCN architecture load
    model = FCN8s(pretrained_net=vgg_model,
                  n_class=games_data.num_classes,
                  dropout_rate=0.4)

    # Loss function
    criterion = nn.CrossEntropyLoss(ignore_index=games_data.ignore_label)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Create directory for checkpoints
    exist_directory(ckpt_path)
    exist_directory(os.path.join(ckpt_path, exp_name))
    with open(
            os.path.join(ckpt_path, exp_name,
                         str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     patience=args.lr_patience,
                                                     min_lr=1e-10)

    # Send model to CUDA device
    vgg_model = vgg_model.to(device)
    model = model.to(device)

    # Use if more than 1 GPU
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # args from parse_args() is an argparse Namespace, so the best-record
    # bookkeeping is kept in a separate dict.
    best_args = {}
    if len(args.snapshot) == 0:
        curr_epoch = 1
        best_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0
        }
    else:
        print('training resumes from ' + args.snapshot)
        model.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args.snapshot)))
        split_snapshot = args.snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        best_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9][:-4])
        }

    criterion.to(device)

    for epoch in range(curr_epoch, args.epochs + 1):
        train(train_loader, model, device, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, model, device, criterion, optimizer,
                            epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)