def main(train_args):
    """Prepare PSPNet training on the Cityscapes 'coarse' split and run it.

    Builds the network, restores weights, constructs the transforms,
    datasets, loaders, loss and SGD optimizer, writes the run configuration
    to a timestamped text file and finally hands everything to ``train``.

    Relies on names defined outside this function: ``ckpt_path``,
    ``exp_name``, ``PSPNet``, ``cityscapes``, ``simul_transforms``,
    ``standard_transforms``, ``extended_transforms``, ``CrossEntropyLoss2d``,
    ``check_mkdir`` and ``train``.

    Args:
        train_args: mapping with the run configuration. Keys read here:
            'snapshot', 'input_size', 'train_batch_size', 'val_batch_size',
            'lr', 'weight_decay' and 'momentum'. The key 'best_record' is
            (over)written by this function.
    """
    net = PSPNet(num_classes=cityscapes.num_classes).cuda()

    snapshot = train_args['snapshot']
    if not snapshot:
        # No snapshot given: start from the fixed checkpoint below.
        net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse-extra)-psp_net', 'xx.pth')))
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        # Parenthesised single-argument print works on Python 2 and 3.
        print('training resumes from ' + snapshot)
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        # The snapshot name is split on '_' and the odd-indexed fields
        # (1, 3, 5, 7, 9, 11) are read back as epoch and metric values.
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    # NOTE(review): these look like the usual ImageNet mean/std values; confirm.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(10),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Undoes the input normalisation and converts back to a PIL image.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('coarse', 'train', simul_transform=train_simul_transform,
                                      transform=train_input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('coarse', 'val', simul_transform=val_simul_transform, transform=val_input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=train_args['val_batch_size'], num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True, ignore_index=cityscapes.ignore_label).cuda()

    # Group 0: parameters whose name ends in 'bias' (2x lr, no weight decay).
    # Group 1: everything else (base lr, with weight decay).
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name.endswith('bias')],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if not name.endswith('bias')],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    if snapshot:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # Override the restored learning rates with the configured ones.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration; the context manager guarantees the file
    # is flushed and closed instead of leaving that to garbage collection.
    config_log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
    with open(config_log_path, 'w') as config_log:
        config_log.write(str(train_args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, train_args, val_loader, restore_transform, visualize)
# Example #2
# 0
def main():
    """Train and validate FCN8ResNet on Cityscapes for the configured epochs.

    Builds the network (optionally restoring a snapshot), the transforms,
    datasets, loaders, a class-weighted loss and a Nesterov SGD optimizer
    with four parameter groups, then alternates ``train`` and ``validate``
    once per epoch.

    Takes no arguments; configuration comes from names defined outside this
    function: ``train_args`` (keys 'snapshot', 'input_size', 'batch_size',
    'new_lr', 'pretrained_lr', 'weight_decay', 'epoch_num'), ``val_args``
    (key 'batch_size'), ``train_record`` (keys 'best_val_loss',
    'corr_mean_iu' and 'corr_epoch' are written when resuming),
    ``num_classes``, ``ignored_label``, ``ckpt_path`` and ``exp_name``.
    """
    net = FCN8ResNet(num_classes=num_classes).cuda()
    snapshot = train_args['snapshot']
    exp_dir = os.path.join(ckpt_path, exp_name)
    if not snapshot:
        curr_epoch = 0
    else:
        # Parenthesised single-argument print works on Python 2 and 3.
        print('training resumes from ' + snapshot)
        net.load_state_dict(torch.load(os.path.join(exp_dir, snapshot)))
        # The snapshot name is split on '_'; fields 1, 3 and 6 are read back
        # as epoch, best validation loss and the corresponding mean IU.
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1])
        train_record['best_val_loss'] = float(split_snapshot[3])
        train_record['corr_mean_iu'] = float(split_snapshot[6])
        train_record['corr_epoch'] = curr_epoch

    net.train()

    # NOTE(review): these look like the usual ImageNet mean/std values; confirm.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Scale up by 1/0.875 first, then crop back down to the input size.
    scaled_size = int(train_args['input_size'][0] / 0.875)
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(scaled_size),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(scaled_size),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    # NOTE(review): ChangeLabel presumably remaps ``ignored_label`` to the
    # last class index, which gets zero loss weight below; confirm.
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train',
                           simul_transform=train_simul_transform,
                           transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['batch_size'],
                              num_workers=16,
                              shuffle=True)
    val_set = CityScapes('val',
                         simul_transform=val_simul_transform,
                         transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=val_args['batch_size'],
                            num_workers=16,
                            shuffle=False)

    # Every class has weight 1 except the last one, which has weight 0.
    weight = torch.ones(num_classes)
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    # Four parameter groups, in this order (the order matters because the
    # learning rates are re-applied by index after loading optimizer state):
    #   0: 'fconv' biases        -> 2 * new_lr, no weight decay
    #   1: 'fconv' non-biases    -> new_lr, weight decay
    #   2: other biases          -> 2 * pretrained_lr, no weight decay
    #   3: other non-biases      -> pretrained_lr, weight decay
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name.endswith('bias') and 'fconv' in name
        ],
        'lr': 2 * train_args['new_lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if not name.endswith('bias') and 'fconv' in name
        ],
        'lr': train_args['new_lr'],
        'weight_decay': train_args['weight_decay']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name.endswith('bias') and 'fconv' not in name
        ],
        'lr': 2 * train_args['pretrained_lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if not name.endswith('bias') and 'fconv' not in name
        ],
        'lr': train_args['pretrained_lr'],
        'weight_decay': train_args['weight_decay']
    }],
                          momentum=0.9,
                          nesterov=True)

    if snapshot:
        optimizer.load_state_dict(
            torch.load(os.path.join(exp_dir, 'opt_' + snapshot)))
        # Override the restored learning rates with the configured ones.
        optimizer.param_groups[0]['lr'] = 2 * train_args['new_lr']
        optimizer.param_groups[1]['lr'] = train_args['new_lr']
        optimizer.param_groups[2]['lr'] = 2 * train_args['pretrained_lr']
        optimizer.param_groups[3]['lr'] = train_args['pretrained_lr']

    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
# Example #3
# 0
def main(train_args):
    """Train and validate FCN8s on the Cityscapes 'fine' split.

    Builds the network (optionally restoring a snapshot), the transforms,
    datasets, loaders, loss and Adam optimizer, writes the run configuration
    to a timestamped text file, then runs ``train`` and ``validate`` once
    per epoch while a ``ReduceLROnPlateau`` scheduler follows the
    validation loss.

    Relies on names defined outside this function: ``ckpt_path``,
    ``exp_name``, ``FCN8s``, ``cityscapes``, ``simul_transforms``,
    ``standard_transforms``, ``extended_transforms``, ``CrossEntropyLoss2d``,
    ``check_mkdir``, ``train`` and ``validate``.

    Args:
        train_args: mapping with the run configuration. Keys read here:
            'snapshot', 'input_size', 'train_batch_size', 'val_batch_size',
            'lr', 'weight_decay', 'momentum', 'lr_patience' and
            'epoch_num'. The key 'best_record' is (over)written.
    """
    net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda()

    snapshot = train_args['snapshot']
    if not snapshot:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        # Parenthesised single-argument print works on Python 2 and 3.
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        # The snapshot name is split on '_' and the odd-indexed fields
        # (1, 3, 5, 7, 9, 11) are read back as epoch and metric values.
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    net.train()

    # NOTE(review): mean-only normalisation on a 0-255 scale together with
    # FlipChannels looks like Caffe-style BGR preprocessing (the network is
    # built with caffe=True); confirm against the pretrained weights.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    # Scale the short side up by 1/0.875, then crop back to the input size.
    short_size = int(min(train_args['input_size']) / 0.875)
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(short_size),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(short_size),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Applies the inverse of each input_transform step, in reverse order.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('fine',
                                      'train',
                                      simul_transform=train_simul_transform,
                                      transform=input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = cityscapes.CityScapes('fine',
                                    'val',
                                    simul_transform=val_simul_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=train_args['val_batch_size'],
                            num_workers=8,
                            shuffle=False)

    criterion = CrossEntropyLoss2d(
        size_average=False, ignore_index=cityscapes.ignore_label).cuda()

    # Group 0: parameters whose name ends in 'bias' (2x lr, no weight decay).
    # Group 1: everything else (base lr, with weight decay).
    optimizer = optim.Adam([{
        'params': [
            param for name, param in net.named_parameters()
            if name.endswith('bias')
        ],
        'lr': 2 * train_args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if not name.endswith('bias')
        ],
        'lr': train_args['lr'],
        'weight_decay': train_args['weight_decay']
    }],
                           betas=(train_args['momentum'], 0.999))

    if snapshot:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # Override the restored learning rates with the configured ones.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration; the context manager guarantees the file
    # is flushed and closed instead of leaving that to garbage collection.
    config_log_path = os.path.join(ckpt_path, exp_name,
                                   str(datetime.datetime.now()) + '.txt')
    with open(config_log_path, 'w') as config_log:
        config_log.write(str(train_args) + '\n\n')

    # Lower the learning rate when the validation loss stops decreasing.
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=train_args['lr_patience'],
                                  min_lr=1e-10,
                                  verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize)
        scheduler.step(val_loss)