def main(train_args):
    """Build the PSPNet / Cityscapes (coarse) training pipeline and run it.

    Keys read from ``train_args``: 'snapshot', 'input_size',
    'train_batch_size', 'val_batch_size', 'lr', 'weight_decay' and
    'momentum'. The key 'best_record' is written by this function.

    When ``train_args['snapshot']`` is empty (or None) the network is
    initialised from the 'xx.pth' checkpoint of the
    'cityscapes (coarse-extra)-psp_net' experiment and training starts at
    epoch 1. Otherwise the named snapshot and its matching 'opt_' optimizer
    state are loaded from ``ckpt_path/exp_name`` and training continues from
    the epoch after the one encoded in the snapshot name.

    The snapshot name is split on '_' and the fields at positions 1, 3, 5, 7,
    9 and 11 are parsed as epoch, val_loss, acc, acc_cls, mean_iu and fwavacc.
    """
    net = PSPNet(num_classes=cityscapes.num_classes).cuda()

    snapshot = train_args['snapshot']
    if not snapshot:
        pretrained_path = os.path.join(
            ckpt_path, 'cityscapes (coarse-extra)-psp_net', 'xx.pth')
        net.load_state_dict(torch.load(pretrained_path))
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0,
        }
    else:
        # Parenthesised single-argument form behaves the same on Python 2
        # and Python 3.
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11]),
        }

    net.train()

    # Per-channel (mean, std) used by Normalize / DeNormalize below.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # "simul" transforms are handed to the dataset as simul_transform;
    # presumably they are applied jointly to image and mask -- TODO confirm
    # against cityscapes.CityScapes.
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(10),
        simul_transforms.RandomHorizontallyFlip(),
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Undoes the input normalisation and converts back to a PIL image.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes(
        'coarse', 'train',
        simul_transform=train_simul_transform,
        transform=train_input_transform,
        target_transform=target_transform)
    train_loader = DataLoader(
        train_set, batch_size=train_args['train_batch_size'],
        num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes(
        'coarse', 'val',
        simul_transform=val_simul_transform,
        transform=val_input_transform,
        target_transform=target_transform)
    val_loader = DataLoader(
        val_set, batch_size=train_args['val_batch_size'],
        num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(
        size_average=True, ignore_index=cityscapes.ignore_label).cuda()

    # Two parameter groups: biases get twice the base learning rate and no
    # explicit weight decay; all other parameters get the base learning rate
    # with weight decay.
    bias_params = [param for name, param in net.named_parameters()
                   if name[-4:] == 'bias']
    other_params = [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias']
    optimizer = optim.SGD([
        {'params': bias_params, 'lr': 2 * train_args['lr']},
        {'params': other_params, 'lr': train_args['lr'],
         'weight_decay': train_args['weight_decay']},
    ], momentum=train_args['momentum'])

    if snapshot:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # The restored state carries the learning rates saved with it;
        # overwrite them with the currently configured ones.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))

    # Record the run configuration in a timestamped text file. The context
    # manager guarantees the handle is flushed and closed before training.
    log_path = os.path.join(
        ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, train_args,
          val_loader, restore_transform, visualize)
def main():
    """Train and validate FCN8ResNet on Cityscapes, optionally resuming.

    Configuration comes from the module-level ``train_args`` and ``val_args``
    dicts. When ``train_args['snapshot']`` is non-empty, the network and
    optimizer states are loaded from ``ckpt_path/exp_name`` and the
    module-level ``train_record`` is updated from fields 1, 3 and 6 of the
    '_'-separated snapshot name.
    """
    net = FCN8ResNet(num_classes=num_classes).cuda()

    if len(train_args['snapshot']) > 0:
        print('training resumes from ' + train_args['snapshot'])
        snapshot_path = os.path.join(
            ckpt_path, exp_name, train_args['snapshot'])
        net.load_state_dict(torch.load(snapshot_path))
        name_fields = train_args['snapshot'].split('_')
        curr_epoch = int(name_fields[1])
        train_record['best_val_loss'] = float(name_fields[3])
        train_record['corr_mean_iu'] = float(name_fields[6])
        train_record['corr_epoch'] = curr_epoch
    else:
        curr_epoch = 0

    net.train()

    # Per-channel (mean, std) used by Normalize / DeNormalize below.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Images are scaled to 1/0.875 of the first input_size entry and then
    # cropped to input_size.
    scaled_size = int(train_args['input_size'][0] / 0.875)
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(scaled_size),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip(),
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(scaled_size),
        simul_transforms.CenterCrop(train_args['input_size']),
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1),
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])

    train_set = CityScapes(
        'train',
        simul_transform=train_simul_transform,
        transform=img_transform,
        target_transform=target_transform)
    train_loader = DataLoader(
        train_set, batch_size=train_args['batch_size'],
        num_workers=16, shuffle=True)
    val_set = CityScapes(
        'val',
        simul_transform=val_simul_transform,
        transform=img_transform,
        target_transform=target_transform)
    val_loader = DataLoader(
        val_set, batch_size=val_args['batch_size'],
        num_workers=16, shuffle=False)

    # The last class index (the one ChangeLabel maps ignored_label to) gets
    # zero weight in the loss.
    class_weight = torch.ones(num_classes)
    class_weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(class_weight).cuda()

    def select_params(want_bias, want_fconv):
        # Parameters whose name ends in 'bias' (or not) and whose name
        # contains 'fconv' (or not).
        return [
            param for name, param in net.named_parameters()
            if name.endswith('bias') == want_bias
            and ('fconv' in name) == want_fconv
        ]

    new_lr = train_args['new_lr']
    pretrained_lr = train_args['pretrained_lr']

    # Bias groups get a doubled learning rate and no weight decay; 'fconv'
    # parameters use new_lr, everything else uses pretrained_lr.
    optimizer = optim.SGD([
        {'params': select_params(True, True),
         'lr': 2 * new_lr},
        {'params': select_params(False, True),
         'lr': new_lr,
         'weight_decay': train_args['weight_decay']},
        {'params': select_params(True, False),
         'lr': 2 * pretrained_lr},
        {'params': select_params(False, False),
         'lr': pretrained_lr,
         'weight_decay': train_args['weight_decay']},
    ], momentum=0.9, nesterov=True)

    if len(train_args['snapshot']) > 0:
        optimizer_path = os.path.join(
            ckpt_path, exp_name, 'opt_' + train_args['snapshot'])
        optimizer.load_state_dict(torch.load(optimizer_path))
        # Replace the restored learning rates with the configured ones.
        configured_lrs = (2 * new_lr, new_lr,
                          2 * pretrained_lr, pretrained_lr)
        for group_idx, group_lr in enumerate(configured_lrs):
            optimizer.param_groups[group_idx]['lr'] = group_lr

    # Create the checkpoint root first, then the experiment directory.
    for directory in (ckpt_path, os.path.join(ckpt_path, exp_name)):
        if not os.path.exists(directory):
            os.mkdir(directory)

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
def main(train_args):
    """Train and validate FCN8s (caffe=True) on the Cityscapes fine split.

    Keys read from ``train_args``: 'snapshot', 'input_size',
    'train_batch_size', 'val_batch_size', 'lr', 'weight_decay', 'momentum',
    'lr_patience' and 'epoch_num'. The key 'best_record' is written by this
    function.

    When ``train_args['snapshot']`` is empty (or None) training starts at
    epoch 1 with a fresh best-record. Otherwise the named snapshot and its
    matching 'opt_' optimizer state are loaded from ``ckpt_path/exp_name``
    and training continues from the epoch after the one encoded in the
    snapshot name.

    The snapshot name is split on '_' and the fields at positions 1, 3, 5, 7,
    9 and 11 are parsed as epoch, val_loss, acc, acc_cls, mean_iu and fwavacc.
    """
    net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda()

    snapshot = train_args['snapshot']
    if not snapshot:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0,
        }
    else:
        # Parenthesised single-argument form behaves the same on Python 2
        # and Python 3.
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11]),
        }

    net.train()

    # Mean subtraction on a 0-255 scale with unit std; combined with
    # FlipChannels this presumably matches Caffe-style BGR preprocessing --
    # TODO confirm against the FCN8s(caffe=True) weights.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    # Scale to 1/0.875 of the smaller input_size entry, then crop to
    # input_size.
    short_size = int(min(train_args['input_size']) / 0.875)
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(short_size),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip(),
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(short_size),
        simul_transforms.CenterCrop(train_args['input_size']),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        # Multiply the tensor by 255 in place before mean subtraction.
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Mirror image of input_transform, applied in reverse order.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels(),
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes(
        'fine', 'train',
        simul_transform=train_simul_transform,
        transform=input_transform,
        target_transform=target_transform)
    train_loader = DataLoader(
        train_set, batch_size=train_args['train_batch_size'],
        num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes(
        'fine', 'val',
        simul_transform=val_simul_transform,
        transform=input_transform,
        target_transform=target_transform)
    val_loader = DataLoader(
        val_set, batch_size=train_args['val_batch_size'],
        num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(
        size_average=False, ignore_index=cityscapes.ignore_label).cuda()

    # Two parameter groups: biases get twice the base learning rate and no
    # explicit weight decay; all other parameters get the base learning rate
    # with weight decay.
    bias_params = [param for name, param in net.named_parameters()
                   if name[-4:] == 'bias']
    other_params = [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias']
    optimizer = optim.Adam([
        {'params': bias_params, 'lr': 2 * train_args['lr']},
        {'params': other_params, 'lr': train_args['lr'],
         'weight_decay': train_args['weight_decay']},
    ], betas=(train_args['momentum'], 0.999))

    if snapshot:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # The restored state carries the learning rates saved with it;
        # overwrite them with the currently configured ones.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))

    # Record the run configuration in a timestamped text file. The context
    # manager guarantees the handle is flushed and closed before training.
    log_path = os.path.join(
        ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    # Lower the learning rate when the validation loss stops decreasing.
    scheduler = ReduceLROnPlateau(
        optimizer, 'min', patience=train_args['lr_patience'],
        min_lr=1e-10, verbose=True)

    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize)
        scheduler.step(val_loss)