def main_worker(gpu, parallel, args, result_dir):
    if parallel:
        args.rank = args.rank + gpu
        torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url,
                                             world_size=args.world_size, rank=args.rank)
    torch.backends.cudnn.benchmark = True
    random_seed(args.seed + args.rank)  # make data augmentation different for different processes
    torch.cuda.set_device(gpu)
    assert args.batch_size % args.world_size == 0

    # Data: each process gets the global batch size divided by the world size.
    from dataset import load_data, get_statistics, default_eps, input_dim
    train_loader, test_loader = load_data(args.dataset, 'data/', args.batch_size // args.world_size,
                                          parallel, augmentation=True, classes=None)
    mean, std = get_statistics(args.dataset)
    num_classes = len(train_loader.dataset.classes)

    # Model: a feature extractor plus either an MLP predictor head or an identity head.
    from model.bound_module import Predictor, BoundFinalIdentity
    from model.mlp import MLPFeature, MLP
    from model.conv import ConvFeature, Conv
    model_name, params = parse_function_call(args.model)
    if args.predictor_hidden_size > 0:
        model = locals()[model_name](input_dim=input_dim[args.dataset], **params)
        predictor = Predictor(model.out_features, args.predictor_hidden_size, num_classes)
    else:
        model = locals()[model_name](input_dim=input_dim[args.dataset], num_classes=num_classes, **params)
        predictor = BoundFinalIdentity()
    model = Model(model, predictor, eps=0)
    model = model.cuda(gpu)
    if parallel:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    loss_name, params = parse_function_call(args.loss)
    loss = Loss(globals()[loss_name](**params), args.kappa)

    # Only rank 0 (or the single process) writes logs.
    output_flag = not parallel or gpu == 0
    if output_flag:
        logger = Logger(os.path.join(result_dir, 'log.txt'))
        for arg in vars(args):
            logger.print(arg, '=', getattr(args, arg))
        logger.print(train_loader.dataset.transform)
        logger.print(model)
        logger.print('number of params: ', sum(p.numel() for p in model.parameters()))
        logger.print('Using loss', loss)
        train_logger = TableLogger(os.path.join(result_dir, 'train.log'), ['epoch', 'loss', 'acc'])
        test_logger = TableLogger(os.path.join(result_dir, 'test.log'), ['epoch', 'loss', 'acc'])
    else:
        logger = train_logger = test_logger = None

    optimizer = AdamW(model, lr=args.lr, weight_decay=args.wd,
                      betas=(args.beta1, args.beta2), eps=args.epsilon)

    # Optionally resume; add or strip the 'module.' prefix so a checkpoint saved
    # with(out) DistributedDataParallel loads in either mode.
    if args.checkpoint:
        assert os.path.isfile(args.checkpoint)
        if parallel:
            torch.distributed.barrier()
        checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage.cuda(gpu))
        state_dict = checkpoint['state_dict']
        if next(iter(state_dict)).startswith('module.') and not parallel:
            state_dict = OrderedDict([(k[7:], v) for k, v in state_dict.items()])
        elif not next(iter(state_dict)).startswith('module.') and parallel:
            state_dict = OrderedDict([('module.' + k, v) for k, v in state_dict.items()])
        model.load_state_dict(state_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded '{}'".format(args.checkpoint))
        if parallel:
            torch.distributed.barrier()

    # Perturbation radii are given in pixel space and rescaled by the normalization
    # statistics; up/down are the valid range of a normalized input.
    if args.eps_test is None:
        args.eps_test = default_eps[args.dataset]
    if args.eps_train is None:
        args.eps_train = args.eps_test
    args.eps_train /= std
    args.eps_test /= std
    up = torch.FloatTensor((1 - mean) / std).view(-1, 1, 1).cuda(gpu)
    down = torch.FloatTensor((0 - mean) / std).view(-1, 1, 1).cuda(gpu)
    attacker = AttackPGD(model, args.eps_test, step_size=args.eps_test / 4, num_steps=20, up=up, down=down)

    args.epochs = [int(epoch) for epoch in args.epochs.split(',')]
    schedule = create_schedule(args, len(train_loader), model, loss, optimizer)

    if args.visualize and output_flag:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(result_dir)
    else:
        writer = None

    for epoch in range(args.start_epoch, args.epochs[-1]):
        if parallel:
            train_loader.sampler.set_epoch(epoch)
        train_loss, train_acc = train(model, loss, epoch, train_loader, optimizer, schedule,
                                      logger, train_logger, gpu, parallel, args.print_freq)
        test_loss, test_acc = test(model, loss, epoch, test_loader, logger, test_logger,
                                   gpu, parallel, args.print_freq)
        if writer is not None:
            writer.add_scalar('curve/p', get_p_norm(model), epoch)
            writer.add_scalar('curve/train loss', train_loss, epoch)
            writer.add_scalar('curve/test loss', test_loss, epoch)
            writer.add_scalar('curve/train acc', train_acc, epoch)
            writer.add_scalar('curve/test acc', test_acc, epoch)
        # Every 50 epochs: a fast (inaccurate) PGD evaluation on both splits.
        if epoch % 50 == 49:
            if logger is not None:
                logger.print('Generate adversarial examples on training dataset and test dataset (fast, inaccurate)')
            robust_train_acc = gen_adv_examples(model, attacker, train_loader, gpu, parallel, logger, fast=True)
            robust_test_acc = gen_adv_examples(model, attacker, test_loader, gpu, parallel, logger, fast=True)
            if writer is not None:
                writer.add_scalar('curve/robust train acc', robust_train_acc, epoch)
                writer.add_scalar('curve/robust test acc', robust_test_acc, epoch)
        # Every 5 epochs: certified accuracy on the test set.
        if epoch % 5 == 4:
            certified_acc = certified_test(model, args.eps_test, up, down, epoch, test_loader, logger, gpu, parallel)
            if writer is not None:
                writer.add_scalar('curve/certified acc', certified_acc, epoch)
        # Final few epochs: full (slow) adversarial and certified evaluation.
        if epoch > args.epochs[-1] - 3:
            if logger is not None:
                logger.print('Generate adversarial examples on test dataset')
            gen_adv_examples(model, attacker, test_loader, gpu, parallel, logger)
            certified_test(model, args.eps_test, up, down, epoch, test_loader, logger, gpu, parallel)

    schedule(args.epochs[-1], 0)
    if output_flag:
        logger.print('Calculate certified accuracy on training dataset and test dataset')
    certified_test(model, args.eps_test, up, down, args.epochs[-1], train_loader, logger, gpu, parallel)
    certified_test(model, args.eps_test, up, down, args.epochs[-1], test_loader, logger, gpu, parallel)

    if output_flag:
        torch.save({
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(result_dir, 'model.pth'))
    if writer is not None:
        writer.close()
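# A minimal launch sketch (assumed, not part of this excerpt): main_worker is
# written for torch.multiprocessing.spawn, which calls fn(i, *args) once per
# process and passes the local GPU index as the first argument. `parse_args` and
# `create_result_dir` are hypothetical helpers standing in for setup code not
# shown here; `args` is expected to carry rank/world_size/dist_url as used above.
def launch():
    args = parse_args()
    result_dir = create_result_dir(args)
    ngpus = torch.cuda.device_count()
    if ngpus > 1:
        args.world_size = ngpus  # assuming single-node training, one process per GPU
        torch.multiprocessing.spawn(main_worker, nprocs=ngpus,
                                    args=(True, args, result_dir))
    else:
        main_worker(0, False, args, result_dir)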
def main():
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.backends.cudnn.deterministic = True

    # set tf env (Inception graph for IS/FID evaluation)
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network: resolve generator/discriminator classes by name
    gen_net = eval('models.' + args.gen_model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.dis_model + '.Discriminator')(args=args).cuda()
    gen_net.set_arch(args.arch, cur_stage=2)

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown initial type'.format(args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    gpu_ids = list(range(torch.cuda.device_count()))
    gen_net = torch.nn.DataParallel(gen_net.to("cuda:0"), device_ids=gpu_ids)
    dis_net = torch.nn.DataParallel(dis_net.to("cuda:0"), device_ids=gpu_ids)
    gen_net.module.cur_stage = 0
    dis_net.module.cur_stage = 0
    gen_net.module.alpha = 1.
    dis_net.module.alpha = 1.

    # set optimizer
    if args.optimizer == "adam":
        gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gen_net.parameters()),
                                         args.g_lr, (args.beta1, args.beta2))
        dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, dis_net.parameters()),
                                         args.d_lr, (args.beta1, args.beta2))
    elif args.optimizer == "adamw":
        gen_optimizer = AdamW(filter(lambda p: p.requires_grad, gen_net.parameters()),
                              args.g_lr, weight_decay=args.wd)
        dis_optimizer = AdamW(filter(lambda p: p.requires_grad, dis_net.parameters()),
                              args.d_lr, weight_decay=args.wd)  # discriminator LR, matching the adam branch
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.fid_stat is not None:
        fid_stat = args.fid_stat
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    dataset = datasets.ImageDataset(args, cur_img_size=8)
    train_loader = dataset.train
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (64, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path)
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])

        # restore the weight-averaged generator
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        cur_stage = cur_stages(start_epoch, args)
        gen_net.module.cur_stage = cur_stage
        dis_net.module.cur_stage = cur_stage
        gen_net.module.alpha = 1.
        dis_net.module.alpha = 1.
        # args.path_helper = checkpoint['path_helper']
        args.path_helper = set_log_dir('logs', args.exp_name)  # a log dir is needed on resume as well
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)), desc='total progress'):
        lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param,
              train_loader, epoch, writer_dict, lr_schedulers)

        # periodically evaluate the weight-averaged generator
        if (epoch and epoch % args.val_freq == 0) or epoch == int(args.max_epoch) - 1:
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat, epoch, gen_net, writer_dict)
            logger.info(f'Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}.')
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint({
            'epoch': epoch + 1,
            'gen_model': args.gen_model,
            'dis_model': args.dis_model,
            'gen_state_dict': gen_net.state_dict(),
            'dis_state_dict': dis_net.state_dict(),
            'avg_gen_state_dict': avg_gen_net.state_dict(),
            'gen_optimizer': gen_optimizer.state_dict(),
            'dis_optimizer': dis_optimizer.state_dict(),
            'best_fid': best_fid,
            'path_helper': args.path_helper,
        }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net
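# `copy_params` / `load_params` are the bookkeeping half of the generator weight
# averaging used above (the averaging update itself happens inside train()); the
# snapshot is swapped in for validation and checkpointing, then swapped back out.
# A minimal sketch of such helpers, assumed rather than taken from this repo:
def copy_params(model):
    # snapshot every parameter tensor, detached from the live model
    return [p.data.clone() for p in model.parameters()]

def load_params(model, params):
    # copy a snapshot back into the live model, relying on parameter order
    for p, snapshot in zip(model.parameters(), params):
        p.data.copy_(snapshot)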
# This excerpt begins inside the resume branch; the opening condition and
# checkpoint load below are a reconstruction (assumed to mirror the pretrain
# branch), and `start_epoch` is presumably initialized to 0 before this point.
if opt.resume is not None:
    checkpoint = torch.load(opt.resume)
    old_opt = checkpoint['opt']
    # the resumed run must use the same optimization hyperparameters
    assert old_opt.lr == opt.lr
    assert old_opt.decay == opt.decay
    assert old_opt.period == opt.period
    assert old_opt.t_mult == opt.t_mult
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    vis.load_state_dict(checkpoint['vis'])
    start_epoch = checkpoint['epoch'] + 1
elif opt.pretrain is not None:
    # initialize weights only; optimizer and scheduler state start fresh
    checkpoint = torch.load(opt.pretrain)
    old_opt = checkpoint['opt']
    # assert old_opt.channels == opt.channels
    # assert old_opt.bands == opt.bands
    assert old_opt.arch == opt.arch
    assert old_opt.blend == opt.blend
    net.load_state_dict(checkpoint['state_dict'])
else:
    assert False, 'either a resume or a pretrain checkpoint must be given'

for epoch in range(start_epoch, opt.n_epochs):
    train(opt, vis, epoch, train_loader, net, optimizer, scheduler)
    miou_val = test(opt, epoch, val_loader, net)
    miou_test = test(opt, epoch, test_loader, net)
    vis.epoch.append(epoch)
    vis.acc.append([miou_val, miou_test])
    vis.plot_acc()
    if (epoch + 1) % opt.period == 0:
        torch.save({
            'epoch': epoch,
            'opt': opt,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'vis': vis.state_dict(),
        }, Path(opt.out_path) / (str(epoch) + '.pth'))
    print('Val mIoU:', miou_val, ' Test mIoU:', miou_test)
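# Round-trip sketch: how a periodic checkpoint written above can be reloaded by
# the resume branch. `build_training_state` is a hypothetical helper standing in
# for the net/optimizer/scheduler/vis construction that precedes this excerpt.
def resume_from(path, opt):
    checkpoint = torch.load(path, map_location='cpu')
    old_opt = checkpoint['opt']
    assert old_opt.lr == opt.lr  # same sanity checks as the resume branch above
    net, optimizer, scheduler, vis = build_training_state(opt)
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    vis.load_state_dict(checkpoint['vis'])
    return net, optimizer, scheduler, vis, checkpoint['epoch'] + 1  # resume at the next epoch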