Example #1
def main_worker(gpu, parallel, args, result_dir):
    if parallel:
        args.rank = args.rank + gpu
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=args.dist_url,
                                             world_size=args.world_size,
                                             rank=args.rank)
    torch.backends.cudnn.benchmark = True
    random_seed(args.seed + args.rank)  # give each process its own data-augmentation stream
    torch.cuda.set_device(gpu)

    assert args.batch_size % args.world_size == 0
    from dataset import load_data, get_statistics, default_eps, input_dim
    train_loader, test_loader = load_data(args.dataset,
                                          'data/',
                                          args.batch_size // args.world_size,
                                          parallel,
                                          augmentation=True,
                                          classes=None)
    mean, std = get_statistics(args.dataset)
    num_classes = len(train_loader.dataset.classes)

    from model.bound_module import Predictor, BoundFinalIdentity
    from model.mlp import MLPFeature, MLP
    from model.conv import ConvFeature, Conv
    model_name, params = parse_function_call(args.model)
    if args.predictor_hidden_size > 0:
        model = locals()[model_name](input_dim=input_dim[args.dataset],
                                     **params)
        predictor = Predictor(model.out_features, args.predictor_hidden_size,
                              num_classes)
    else:
        model = locals()[model_name](input_dim=input_dim[args.dataset],
                                     num_classes=num_classes,
                                     **params)
        predictor = BoundFinalIdentity()
    model = Model(model, predictor, eps=0)
    model = model.cuda(gpu)
    if parallel:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[gpu])

    loss_name, params = parse_function_call(args.loss)
    loss = Loss(globals()[loss_name](**params), args.kappa)

    # Only the rank-0 process (or the single non-parallel process) writes logs.
    output_flag = not parallel or gpu == 0
    if output_flag:
        logger = Logger(os.path.join(result_dir, 'log.txt'))
        for arg in vars(args):
            logger.print(arg, '=', getattr(args, arg))
        logger.print(train_loader.dataset.transform)
        logger.print(model)
        logger.print('number of params: ',
                     sum(p.numel() for p in model.parameters()))
        logger.print('Using loss', loss)
        train_logger = TableLogger(os.path.join(result_dir, 'train.log'),
                                   ['epoch', 'loss', 'acc'])
        test_logger = TableLogger(os.path.join(result_dir, 'test.log'),
                                  ['epoch', 'loss', 'acc'])
    else:
        logger = train_logger = test_logger = None

    optimizer = AdamW(model,
                      lr=args.lr,
                      weight_decay=args.wd,
                      betas=(args.beta1, args.beta2),
                      eps=args.epsilon)

    if args.checkpoint:
        assert os.path.isfile(args.checkpoint)
        if parallel:
            torch.distributed.barrier()
        checkpoint = torch.load(
            args.checkpoint,
            map_location=lambda storage, loc: storage.cuda(gpu))
        state_dict = checkpoint['state_dict']
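        # Normalize the 'module.' prefix that DistributedDataParallel adds,
        # so a checkpoint saved in either mode loads in the other.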
        if next(iter(state_dict))[0:7] == 'module.' and not parallel:
            new_state_dict = OrderedDict([(k[7:], v)
                                          for k, v in state_dict.items()])
            state_dict = new_state_dict
        elif next(iter(state_dict))[0:7] != 'module.' and parallel:
            new_state_dict = OrderedDict([('module.' + k, v)
                                          for k, v in state_dict.items()])
            state_dict = new_state_dict
        model.load_state_dict(state_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded '{}'".format(args.checkpoint))
        if parallel:
            torch.distributed.barrier()

    if args.eps_test is None:
        args.eps_test = default_eps[args.dataset]
    if args.eps_train is None:
        args.eps_train = args.eps_test
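    # Inputs are normalized by (mean, std), so scale eps into the same space.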
    args.eps_train /= std
    args.eps_test /= std
    up = torch.FloatTensor((1 - mean) / std).view(-1, 1, 1).cuda(gpu)
    down = torch.FloatTensor((0 - mean) / std).view(-1, 1, 1).cuda(gpu)
    attacker = AttackPGD(model,
                         args.eps_test,
                         step_size=args.eps_test / 4,
                         num_steps=20,
                         up=up,
                         down=down)
    args.epochs = [int(epoch) for epoch in args.epochs.split(',')]
    schedule = create_schedule(args, len(train_loader), model, loss, optimizer)

    if args.visualize and output_flag:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(result_dir)
    else:
        writer = None

    for epoch in range(args.start_epoch, args.epochs[-1]):
        if parallel:
            train_loader.sampler.set_epoch(epoch)
        train_loss, train_acc = train(model, loss, epoch, train_loader,
                                      optimizer, schedule, logger,
                                      train_logger, gpu, parallel,
                                      args.print_freq)
        test_loss, test_acc = test(model, loss, epoch, test_loader, logger,
                                   test_logger, gpu, parallel, args.print_freq)
        if writer is not None:
            writer.add_scalar('curve/p', get_p_norm(model), epoch)
            writer.add_scalar('curve/train loss', train_loss, epoch)
            writer.add_scalar('curve/test loss', test_loss, epoch)
            writer.add_scalar('curve/train acc', train_acc, epoch)
            writer.add_scalar('curve/test acc', test_acc, epoch)
        if epoch % 50 == 49:
            if logger is not None:
                logger.print(
                    'Generate adversarial examples on training dataset and test dataset (fast, inaccurate)'
                )
            robust_train_acc = gen_adv_examples(model,
                                                attacker,
                                                train_loader,
                                                gpu,
                                                parallel,
                                                logger,
                                                fast=True)
            robust_test_acc = gen_adv_examples(model,
                                               attacker,
                                               test_loader,
                                               gpu,
                                               parallel,
                                               logger,
                                               fast=True)
            if writer is not None:
                writer.add_scalar('curve/robust train acc', robust_train_acc,
                                  epoch)
                writer.add_scalar('curve/robust test acc', robust_test_acc,
                                  epoch)
        if epoch % 5 == 4:
            certified_acc = certified_test(model, args.eps_test, up, down,
                                           epoch, test_loader, logger, gpu,
                                           parallel)
            if writer is not None:
                writer.add_scalar('curve/certified acc', certified_acc, epoch)
        if epoch > args.epochs[-1] - 3:
            if logger is not None:
                logger.print("Generate adversarial examples on test dataset")
            gen_adv_examples(model, attacker, test_loader, gpu, parallel,
                             logger)
            certified_test(model, args.eps_test, up, down, epoch, test_loader,
                           logger, gpu, parallel)

    schedule(args.epochs[-1], 0)
    if output_flag:
        logger.print(
            "Calculate certified accuracy on training dataset and test dataset"
        )
    certified_test(model, args.eps_test, up, down, args.epochs[-1],
                   train_loader, logger, gpu, parallel)
    certified_test(model, args.eps_test, up, down, args.epochs[-1],
                   test_loader, logger, gpu, parallel)

    if output_flag:
        torch.save(
            {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(result_dir, 'model.pth'))
    if writer is not None:
        writer.close()
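
Example #1's main_worker takes the local GPU index as its first argument, which is
exactly the calling convention of torch.multiprocessing.spawn; it also assumes
args.rank, args.world_size and args.dist_url are set before it runs. A minimal
launcher sketch under those assumptions (the result_dir value is a hypothetical
placeholder):

import torch
import torch.multiprocessing as mp

def launch(args, result_dir='results/run0'):  # result_dir value is hypothetical
    ngpus = torch.cuda.device_count()
    if ngpus > 1:
        # One process per GPU: spawn() calls main_worker(i, *args)
        # with i = 0..ngpus-1, which becomes the `gpu` argument.
        mp.spawn(main_worker, nprocs=ngpus, args=(True, args, result_dir))
    else:
        main_worker(0, False, args, result_dir)
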
Example #2
def main():
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.backends.cudnn.deterministic = True

    # set up the TensorFlow Inception graph used for IS/FID evaluation
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # build the generator and discriminator, resolving each class by name
    gen_net = eval('models.' + args.gen_model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.dis_model +
                   '.Discriminator')(args=args).cuda()
    gen_net.set_arch(args.arch, cur_stage=2)

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown initialization type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # Wrap both networks in DataParallel so batches split across all visible GPUs.
    gpu_ids = list(range(torch.cuda.device_count()))
    gen_net = torch.nn.DataParallel(gen_net.to("cuda:0"), device_ids=gpu_ids)
    dis_net = torch.nn.DataParallel(dis_net.to("cuda:0"), device_ids=gpu_ids)

    gen_net.module.cur_stage = 0
    dis_net.module.cur_stage = 0
    gen_net.module.alpha = 1.
    dis_net.module.alpha = 1.

    # set optimizer
    if args.optimizer == "adam":
        gen_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
            (args.beta1, args.beta2))
        dis_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
            (args.beta1, args.beta2))
    elif args.optimizer == "adamw":
        gen_optimizer = AdamW(filter(lambda p: p.requires_grad,
                                     gen_net.parameters()),
                              args.g_lr,
                              weight_decay=args.wd)
        dis_optimizer = AdamW(filter(lambda p: p.requires_grad,
                                     dis_net.parameters()),
                              args.d_lr,
                              weight_decay=args.wd)
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.fid_stat is not None:
        fid_stat = args.fid_stat
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # scale the epoch count by n_critic (discriminator steps per generator step)
    args.max_epoch = args.max_epoch * args.n_critic
    dataset = datasets.ImageDataset(args, cur_img_size=8)
    train_loader = dataset.train
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial state: fixed noise for sampling, EMA weight snapshot, bookkeeping
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (64, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # resume from a checkpoint if given, otherwise start a fresh run
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        checkpoint_file = args.load_path
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net
        cur_stage = cur_stages(start_epoch, args)
        gen_net.module.cur_stage = cur_stage
        dis_net.module.cur_stage = cur_stage
        gen_net.module.alpha = 1.
        dis_net.module.alpha = 1.

        # args.path_helper = checkpoint['path_helper']

    else:
        # create new log dir
        assert args.exp_name
    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, lr_schedulers)

        if (epoch and epoch % args.val_freq == 0) or epoch == int(
                args.max_epoch) - 1:
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  epoch, gen_net, writer_dict)
            logger.info(
                f'Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}.'
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_model': args.gen_model,
                'dis_model': args.dis_model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net
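
Example #2 keeps a moving average of the generator weights via copy_params and
load_params, whose definitions are not shown. A plausible minimal pair implementing
that snapshot/restore pattern (the bodies below are an assumption, not the
project's actual code):

def copy_params(net):
    # Snapshot detached copies of the current parameter tensors.
    return [p.detach().clone() for p in net.parameters()]

def load_params(net, params):
    # Write a snapshot back into the network's parameters in-place.
    for p, saved in zip(net.parameters(), params):
        p.data.copy_(saved)
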
Example #3
        assert old_opt.lr == opt.lr
        assert old_opt.decay == opt.decay
        assert old_opt.period == opt.period
        assert old_opt.t_mult == opt.t_mult
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        vis.load_state_dict(checkpoint['vis'])
        start_epoch = checkpoint['epoch'] + 1
    elif opt.pretrain is not None:
        checkpoint = torch.load(opt.pretrain)
        old_opt = checkpoint['opt']
        # assert old_opt.channels == opt.channels
        # assert old_opt.bands == opt.bands
        assert old_opt.arch == opt.arch
        assert old_opt.blend == opt.blend
        net.load_state_dict(checkpoint['state_dict'])
    else:
        assert False  # neither a resume checkpoint nor pretrained weights were given

    for epoch in range(start_epoch, opt.n_epochs):
        train(opt, vis, epoch, train_loader, net, optimizer, scheduler)
        miou_val = test(opt, epoch, val_loader, net)
        miou_test = test(opt, epoch, test_loader, net)
        vis.epoch.append(epoch)
        vis.acc.append([miou_val, miou_test])
        vis.plot_acc()
        if (epoch + 1) % opt.period == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'opt': opt,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'vis': vis.state_dict()
                },
                Path(opt.out_path) / (str(epoch) + '.pth'))
        print('Val mIoU:', miou_val, ' Test mIoU:', miou_test)
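
Both checkpoint branches above treat vis like a model: it exposes state_dict() and
load_state_dict() so plotting state survives a restart. A hypothetical minimal
class satisfying that interface (only epoch, acc, and plot_acc appear in the loop
above; the bodies here are a sketch, not the original implementation):

class Vis:
    """Per-epoch metric tracker with model-style checkpointing."""

    def __init__(self):
        self.epoch = []  # epochs recorded so far
        self.acc = []    # [val_mIoU, test_mIoU] pairs

    def state_dict(self):
        return {'epoch': self.epoch, 'acc': self.acc}

    def load_state_dict(self, state):
        self.epoch = state['epoch']
        self.acc = state['acc']

    def plot_acc(self):
        pass  # plotting backend omitted in this sketch
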
Example #4
    elif opt.pretrain is not None:
        checkpoint = torch.load(opt.pretrain)
        old_opt = checkpoint['opt']
        assert old_opt.channels == opt.channels
        assert old_opt.bands == opt.bands
        assert old_opt.arch == opt.arch
        assert old_opt.blend == opt.blend
        net.load_state_dict(checkpoint['state_dict'])
    else:
        assert False  # neither a resume checkpoint nor pretrained weights were given

    for epoch in range(start_epoch, opt.n_epochs):
        train(opt, vis, epoch, train_loader, net, optimizer, scheduler)
        miou_val = test(opt, epoch, val_loader, net)
        miou_test = test(opt, epoch, test_loader, net)
        vis.epoch.append(epoch)
        vis.acc.append([miou_val, miou_test])
        vis.plot_acc()
        if (epoch + 1) % opt.period == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'opt': opt,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'vis': vis.state_dict()
                },
                Path(opt.out_path) / (str(epoch) + '.pth'))
        print('Val mIoU:', miou_val, ' Test mIoU:', miou_test)
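
The periodic checkpoint written above contains everything the resume branch in
Example #3 reads back. A minimal sketch of the matching load side, assuming the
objects were reconstructed the same way before calling it (resume_from is a
hypothetical helper, not part of the original code):

from pathlib import Path
import torch

def resume_from(out_path, epoch, net, optimizer, scheduler, vis):
    ckpt = torch.load(Path(out_path) / f'{epoch}.pth')
    net.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    vis.load_state_dict(ckpt['vis'])
    return ckpt['epoch'] + 1  # next epoch to train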