def main(args: argparse.Namespace):
    """Train or evaluate an adversarially adapted segmentation model.

    Builds the source/target data loaders, the segmentation model and its
    discriminator, their optimizers and LR schedulers, optionally resumes
    from a checkpoint, and then either runs a single validation pass
    (``args.phase == 'test'``) or the full training loop with a checkpoint
    saved after every epoch.

    Args:
        args: parsed command-line options. Attributes read here: ``log``,
            ``phase``, ``seed``, ``source``, ``source_root``, ``target``,
            ``target_root``, ``train_size``, ``resize_ratio``,
            ``test_input_size``, ``test_output_size``, ``batch_size``,
            ``workers``, ``arch``, ``lr``, ``lr_d``, ``momentum``,
            ``weight_decay``, ``lr_power``, ``epochs``, ``iters_per_epoch``,
            ``resume``, ``start_epoch``, ``ignore_label`` and ``debug``.
            ``args.start_epoch`` is overwritten when resuming.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    seeded = args.seed is not None
    if seeded:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # cuDNN benchmark mode auto-tunes kernels and may select different
    # algorithms from run to run, so it is only enabled for unseeded runs.
    cudnn.benchmark = not seeded

    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=args.resize_ratio,
                                scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=(2., 2.),
                                scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root,
        split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size,
                     label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   pin_memory=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(),
                       lr=args.lr_d,
                       betas=(0.9, 0.99))
    total_iters = args.epochs * args.iters_per_epoch
    # LambdaLR multiplies each group's initial lr by the lambda's value.
    # The model-side lambda carries args.lr as an extra factor on purpose;
    # NOTE(review): presumably model.get_parameters() supplies relative
    # per-group rates -- confirm against the model definition.
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x:
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))
    del total_iters  # documented above only; the lambdas keep the original arithmetic

    # optionally resume from a checkpoint
    if args.resume:
        # NOTE: the checkpoint also stores the argparse namespace, so this is
        # a full unpickle -- only resume from checkpoint files you trust.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    # sizes are reversed before being handed to nn.Upsample
    interp_train = nn.Upsample(size=args.train_size[::-1],
                               mode='bilinear',
                               align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1],
                             mode='bilinear',
                             align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """
        Save the input image, the ground-truth label map and the predicted
        label map as PNG files under the logger's image directory.

        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        # per-pixel argmax over the class dimension
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))),
             "image"),
            (decode(label), "label"),
            (decode(pred), "pred")
        ]:
            tensor.save(
                logger.get_image_path("{}_{}.png".format(prefix, name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize, args)
        print(confmat)
        # close the logger on this early exit too, as the training path does
        logger.close()
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        # get_last_lr() is the supported way to read the current rates;
        # calling get_lr() outside of step() triggers a warning.
        print(lr_scheduler.get_last_lr(), lr_scheduler_d.get_last_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train,
              criterion, dann, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           None, args)
        print(confmat.format(train_source_dataset.classes))
        _, _, iu = confmat.compute()

        # calculate the mean iou over partial classes
        indexes = [
            train_source_dataset.classes.index(name)
            for name in train_source_dataset.evaluate_classes
        ]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # save this epoch's checkpoint, then copy it to 'best' if it improved
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()
def main(args):
    """Train a CycleGAN between the source and target domains, or translate.

    Builds the source/target data loaders, two generators (S2T, T2S) and two
    discriminators, their optimizers and LR schedulers, and optionally
    resumes from a checkpoint. With ``args.phase == 'test'`` it only runs the
    source-dataset translation and returns. Otherwise it trains for
    ``args.epochs + args.epochs_decay`` epochs, saving a checkpoint after
    each one, and finally runs the translation when ``args.translated_root``
    is set.

    Args:
        args: parsed command-line options. Attributes read here: ``log``,
            ``phase``, ``seed``, ``source``, ``source_root``, ``target``,
            ``target_root``, ``train_size``, ``resize_ratio``,
            ``test_input_size``, ``batch_size``, ``workers``, ``netG``,
            ``netD``, ``ngf``, ``ndf``, ``norm``, ``pool_size``, ``lr``,
            ``beta1``, ``epochs``, ``epochs_decay``, ``resume``,
            ``start_epoch`` and ``translated_root``. ``args.start_epoch`` is
            overwritten when resuming.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    seeded = args.seed is not None
    if seeded:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # cuDNN benchmark mode auto-tunes kernels and may select different
    # algorithms from run to run, so it is only enabled for unseeded runs.
    cudnn.benchmark = not seeded

    # Data loading code
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size,
                            ratio=args.resize_ratio,
                            scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(),
                                       netG_T2S.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(),
                                       netD_T.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))

    def lr_decay_function(epoch):
        """Return the LR multiplier: 1.0 up to args.epochs, then a linear decay."""
        return 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)

    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        # NOTE: the checkpoint also stores the argparse namespace, so this is
        # a full unpickle -- only resume from checkpoint files you trust.
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    def translate_source():
        """Run train_source_dataset.translate with a resize + netG_S2T transform."""
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    if args.phase == 'test':
        translate_source()
        # close the logger on this early exit too, as the training path does
        logger.close()
        return

    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # define visualization function
    tensor_to_image = Compose(
        [Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
         ToPILImage()])

    def visualize(image, name):
        """
        Save one image tensor as a PNG under the logger's image directory.

        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(
            logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        # get_last_lr() is the supported way to read the current rate;
        # calling get_lr() outside of step() triggers a warning.
        print(lr_scheduler_G.get_last_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S,
              netD_S, netD_T, criterion_gan, criterion_cycle,
              criterion_identity, optimizer_G, optimizer_D, fake_S_pool,
              fake_T_pool, epoch, visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))

    if args.translated_root is not None:
        translate_source()

    logger.close()