def main():
    # model = torch.nn.DataParallel(ResNet18(num_classes=args.num_classes).to(device))
    model = torch.nn.DataParallel(
        ImplicitResNet18(num_classes=args.num_classes).to(device))
    # One forward pass under no_grad to trigger any data-dependent
    # initialization (e.g. actnorm) before the EMA copy is created.
    with torch.no_grad():
        x, _ = next(iter(train_loader))
        x = x.to(device)
        model(x)
    ema = utils.ExponentialMovingAverage(model)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.99),
                           weight_decay=args.weight_decay)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    logger.info(model)
    logger.info('EMA: {}'.format(ema))
    logger.info(optimizer)

    for epoch in range(1, args.epochs + 1):
        # adjust learning rate for SGD
        adjust_learning_rate(optimizer, epoch)

        # adversarial training
        train(args, model, device, train_loader, optimizer, epoch, ema)

        # evaluation on natural examples
        logger.info(
            '================================================================')
        eval_train(model, device, train_loader, ema)
        eval_test(model, device, test_loader, ema)
        logger.info(
            '================================================================')

        # save checkpoint
        if epoch == args.epochs:
            torch.save(
                model.state_dict(),
                os.path.join(model_dir,
                             'model-wideres-epoch{}.pt'.format(epoch)))
            torch.save(
                optimizer.state_dict(),
                os.path.join(
                    model_dir,
                    'opt-wideres-checkpoint_epoch{}.tar'.format(epoch)))
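Each of these examples constructs utils.ExponentialMovingAverage(model) right after building the model, but the helper itself is not shown. A minimal sketch of such a utility, assuming it only needs to keep a decayed copy of the parameters and load it back for evaluation, could look like the following (SimpleEMA and its method names are illustrative, not the repository's API):

import torch


class SimpleEMA:
    """Illustrative EMA helper (not the repository's implementation)."""

    def __init__(self, model, decay=0.999):
        self.decay = decay
        # Keep a detached copy of every parameter.
        self.shadow = {
            name: p.detach().clone() for name, p in model.named_parameters()
        }

    def update(self, model):
        # shadow <- decay * shadow + (1 - decay) * current parameter value
        with torch.no_grad():
            for name, p in model.named_parameters():
                self.shadow[name].mul_(self.decay).add_(p, alpha=1 - self.decay)

    def copy_to(self, model):
        # Load the averaged weights into the model, e.g. before evaluation.
        with torch.no_grad():
            for name, p in model.named_parameters():
                p.copy_(self.shadow[name])

A helper like this would be updated once per optimizer step during training and copied into (a clone of) the model before validation.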
Example #2
# (continuation of a model-constructor call; the opening lines are not part of this example)
    fc_end=args.fc_end,
    fc_idim=args.fc_idim,
    n_exact_terms=args.n_exact_terms,
    preact=args.preact,
    neumann_grad=args.neumann_grad,
    grad_in_forward=args.mem_eff,
    first_resblock=args.first_resblock,
    learn_p=args.learn_p,
    classification=args.task in ['classification', 'hybrid'],
    classification_hdim=args.cdim,
    n_classes=n_classes,
    block_type=args.block,
)

model.to(device)
ema = utils.ExponentialMovingAverage(model)


def parallelize(model):
    return torch.nn.DataParallel(model)


logger.info(model)
logger.info('EMA: {}'.format(ema))


# Optimization
def tensor_in(t, a):
    # Identity-based membership test: True iff `t` is the same tensor object
    # as some element of `a`.
    for a_ in a:
        if t is a_:
            return True
    return False
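tensor_in performs an identity check rather than an equality check, which is useful when specific parameter tensors must be routed into their own optimizer group. A hypothetical usage sketch (the subset selection and hyperparameters below are illustrative, not taken from this example):

import torch.optim as optim

# Hypothetical: give an arbitrary subset of parameters its own learning rate.
special = list(model.parameters())[:2]          # illustrative subset
special_group, default_group = [], []
for p in model.parameters():
    (special_group if tensor_in(p, special) else default_group).append(p)

optimizer = optim.Adam(
    [{'params': default_group}, {'params': special_group, 'lr': 1e-4}],
    lr=1e-3,
)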
Example #3
def main(rank, world_size, args):
    setup(rank, world_size, args.port)

    # setup logger
    if rank == 0:
        utils.makedirs(args.save)
        logger = utils.get_logger(os.path.join(args.save, "logs"))

    def mprint(msg):
        if rank == 0:
            logger.info(msg)

    mprint(args)

    device = torch.device(
        f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')

    if device.type == 'cuda':
        mprint('Found {} CUDA devices.'.format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint('{} \t Memory: {:.2f}GB'.format(
                props.name, props.total_memory / (1024**3)))
    else:
        mprint('WARNING: Using device {}'.format(device))

    np.random.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed + rank)

    mprint('Loading dataset {}'.format(args.data))
    # Dataset and hyperparameters
    if args.data == 'cifar10':
        im_dim = 3

        transform_train = transforms.Compose([
            transforms.Resize(args.imagesize),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            add_noise if args.add_noise else identity,
        ])
        transform_test = transforms.Compose([
            transforms.Resize(args.imagesize),
            transforms.ToTensor(),
            add_noise if args.add_noise else identity,
        ])

        init_layer = flows.LogitTransform(0.05)
        # NOTE: this 'cifar10' branch builds its datasets from vdsets.SVHN.
        train_set = vdsets.SVHN(args.dataroot,
                                download=True,
                                split="train",
                                transform=transform_train)
        sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batchsize,
            sampler=sampler,
        )
        test_loader = torch.utils.data.DataLoader(
            vdsets.SVHN(args.dataroot,
                        download=True,
                        split="test",
                        transform=transform_test),
            batch_size=args.val_batchsize,
            shuffle=False,
        )

    elif args.data == 'mnist':
        im_dim = 1
        init_layer = flows.LogitTransform(1e-6)
        train_set = datasets.MNIST(
            args.dataroot,
            train=True,
            transform=transforms.Compose([
                transforms.Resize(args.imagesize),
                transforms.ToTensor(),
                add_noise if args.add_noise else identity,
            ]))
        sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batchsize,
            sampler=sampler,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.dataroot,
                           train=False,
                           transform=transforms.Compose([
                               transforms.Resize(args.imagesize),
                               transforms.ToTensor(),
                               add_noise if args.add_noise else identity,
                           ])),
            batch_size=args.val_batchsize,
            shuffle=False,
        )
    else:
        raise Exception(f'dataset not one of mnist / cifar10, got {args.data}')

    mprint('Dataset loaded.')
    mprint('Creating model.')

    input_size = (args.batchsize, im_dim, args.imagesize, args.imagesize)

    model = MultiscaleFlow(
        input_size,
        block_fn=partial(cpflow_block_fn,
                         block_type=args.block_type,
                         dimh=args.dimh,
                         num_hidden_layers=args.num_hidden_layers,
                         icnn_version=args.icnn,
                         num_pooling=args.num_pooling),
        n_blocks=list(map(int, args.nblocks.split('-'))),
        factor_out=args.factor_out,
        init_layer=init_layer,
        actnorm=args.actnorm,
        fc_end=args.fc_end,
        glow=args.glow,
    )
    model.to(device)

    model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    ema = utils.ExponentialMovingAverage(model)

    mprint(model)
    mprint('EMA: {}'.format(ema))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.99),
                           weight_decay=args.wd)

    # Saving and resuming
    best_test_bpd = math.inf
    begin_epoch = 0

    most_recent_path = os.path.join(args.save, 'models', 'most_recent.pth')
    checkpt_exists = os.path.exists(most_recent_path)
    if checkpt_exists:
        mprint(f"Resuming from {most_recent_path}")

        # deal with data-dependent initialization like actnorm.
        with torch.no_grad():
            x = torch.rand(8, *input_size[1:]).to(device)
            model(x)

        checkpt = torch.load(most_recent_path)
        begin_epoch = checkpt["epoch"] + 1

        model.module.load_state_dict(checkpt["state_dict"])
        ema.set(checkpt['ema'])
        optimizer.load_state_dict(checkpt["opt_state_dict"])
    elif args.resume:
        mprint(f"Resuming from {args.resume}")

        # deal with data-dependent initialization like actnorm.
        with torch.no_grad():
            x = torch.rand(8, *input_size[1:]).to(device)
            model(x)

        checkpt = torch.load(args.resume)
        begin_epoch = checkpt["epoch"] + 1

        model.module.load_state_dict(checkpt["state_dict"])
        ema.set(checkpt['ema'])
        optimizer.load_state_dict(checkpt["opt_state_dict"])

    mprint(optimizer)

    batch_time = utils.RunningAverageMeter(0.97)
    bpd_meter = utils.RunningAverageMeter(0.97)
    gnorm_meter = utils.RunningAverageMeter(0.97)
    cg_meter = utils.RunningAverageMeter(0.97)
    hnorm_meter = utils.RunningAverageMeter(0.97)

    update_lr(optimizer, 0, args)

    # for visualization
    fixed_x = next(iter(train_loader))[0][:8].to(device)
    fixed_z = torch.randn(8,
                          im_dim * args.imagesize * args.imagesize).to(fixed_x)
    if rank == 0:
        utils.makedirs(os.path.join(args.save, 'figs'))
        # visualize(model, fixed_x, fixed_z, os.path.join(args.save, 'figs', 'init.png'))
    for epoch in range(begin_epoch, args.nepochs):
        sampler.set_epoch(epoch)
        flows.CG_ITERS_TRACER.clear()
        flows.HESS_NORM_TRACER.clear()
        mprint('Current LR {}'.format(optimizer.param_groups[0]['lr']))
        train(epoch, train_loader, model, optimizer, bpd_meter, gnorm_meter,
              cg_meter, hnorm_meter, batch_time, ema, device, mprint,
              world_size, args)
        val_time, test_bpd = validate(epoch, model, test_loader, ema, device)
        mprint(
            'Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {test_bpd:.4f}'.format(
                epoch, val_time, test_bpd=test_bpd))

        if rank == 0:
            utils.makedirs(os.path.join(args.save, 'figs'))
            visualize(model, fixed_x, fixed_z,
                      os.path.join(args.save, 'figs', f'{epoch}.png'))

            utils.makedirs(os.path.join(args.save, "models"))
            if test_bpd < best_test_bpd:
                best_test_bpd = test_bpd
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model.module.state_dict(),
                        'opt_state_dict': optimizer.state_dict(),
                        'args': args,
                        'ema': ema,
                        'test_bpd': test_bpd,
                    }, os.path.join(args.save, 'models', 'best_model.pth'))

        if rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.module.state_dict(),
                    'opt_state_dict': optimizer.state_dict(),
                    'args': args,
                    'ema': ema,
                    'test_bpd': test_bpd,
                }, os.path.join(args.save, 'models', 'most_recent.pth'))

    cleanup()
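The setup and cleanup functions called at the top and bottom of main are not part of this example. A common pattern for them, sketched here under the assumption of a single-node launch with torch.distributed's env:// initialization (the localhost address and backend choice are assumptions):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def setup(rank, world_size, port):
    # Sketch of the usual process-group initialization assumed by main().
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    dist.init_process_group('nccl' if torch.cuda.is_available() else 'gloo',
                            rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


# Typical launch: one process per GPU, with rank supplied by mp.spawn, e.g.
# mp.spawn(main, args=(world_size, args), nprocs=world_size, join=True)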