예제 #1
0
def get_loaders(
    traindir,
    valdir,
    sz,
    bs,
    fp16=True,
    val_bs=None,
    workers=8,
    rect_val=False,
    min_scale=0.08,
    distributed=False,
    synthetic=False,
):
    """Create train and validation dataloaders wrapped for fp16 batch transforms.

    Args:
        traindir: ImageFolder-style root for training images.
        valdir: root passed to create_validation_set.
        sz: crop size fed to the model.
        bs: train batch size; val_bs defaults to bs when falsy.
        fp16: have BatchTransformDataLoader cast batches to half precision.
        val_bs: validation batch size override.
        workers: dataloader worker count.
        rect_val: use rectangular validation batches (see create_validation_set).
        min_scale: lower bound of RandomResizedCrop's area scale range.
        distributed: shard the train set with a DistributedSampler.
        synthetic: replace the train loader with random synthetic batches.

    Returns:
        (train_loader, val_loader, train_sampler, val_sampler)
    """
    val_bs = val_bs or bs
    train_tfms = [transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)), transforms.RandomHorizontalFlip()]
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms))
    train_sampler = (
        DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank()) if distributed else None
    )

    if synthetic:
        print("Using synthetic dataloader")
        train_loader = SyntheticDataLoader(bs, (3, sz, sz))
    else:
        # The spawn and default paths previously duplicated the whole
        # DataLoader(...) call; they differ only in the multiprocessing
        # context, so build shared kwargs once.
        loader_kwargs = dict(
            batch_size=bs,
            shuffle=(train_sampler is None),
            num_workers=workers,
            pin_memory=True,
            collate_fn=fast_collate,
            sampler=train_sampler,
        )
        if util.is_set("PYTORCH_USE_SPAWN"):
            print("Using SPAWN method for dataloader")
            loader_kwargs["multiprocessing_context"] = "spawn"
        train_loader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)

    val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, num_workers=workers, pin_memory=True, collate_fn=fast_collate, batch_sampler=val_sampler
    )

    train_loader = BatchTransformDataLoader(train_loader, fp16=fp16)
    val_loader = BatchTransformDataLoader(val_loader, fp16=fp16)

    return train_loader, val_loader, train_sampler, val_sampler
예제 #2
0
    def __init__(self, indices, batch_size, distributed=True):
        """Split ``indices`` across ranks in fixed-size batches.

        Every rank is assigned the same number of batches so that distributed
        validation stays in lockstep even when the data does not divide evenly.
        """
        self.indices = indices
        self.batch_size = batch_size
        self.world_size = env_world_size() if distributed else 1
        self.global_rank = env_rank() if distributed else 0

        # Batches each rank must produce: ceil(total / world_size / batch_size).
        # Keeping this identical on every GPU avoids ranks finishing early.
        per_rank = len(self.indices) / self.world_size
        self.expected_num_batches = math.ceil(per_rank / self.batch_size)

        # Number of samples handed to each GPU.
        self.num_samples = self.expected_num_batches * self.batch_size
예제 #3
0
def get_loaders(traindir,
                valdir,
                sz,
                bs,
                fp16=False,
                val_bs=None,
                workers=8,
                rect_val=False,
                min_scale=0.08,
                distributed=False):
    """Build train/validation dataloaders, each wrapped in a
    BatchTransformDataLoader that optionally casts batches to fp16.

    Returns:
        (train_loader, val_loader, train_sampler, val_sampler)
    """
    if not val_bs:
        val_bs = bs

    augmentations = transforms.Compose([
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ])
    train_dataset = datasets.ImageFolder(traindir, augmentations)

    if distributed:
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=env_world_size(), rank=env_rank())
    else:
        train_sampler = None

    # Shuffle only when no sampler controls the ordering.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=bs,
        shuffle=train_sampler is None,
        num_workers=workers,
        pin_memory=True,
        collate_fn=fast_collate,
        sampler=train_sampler,
    )

    val_dataset, val_sampler = create_validation_set(
        valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    # Validation batching is fully driven by the batch sampler.
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers,
        pin_memory=True,
        collate_fn=fast_collate,
        batch_sampler=val_sampler,
    )

    return (BatchTransformDataLoader(train_loader, fp16=fp16),
            BatchTransformDataLoader(val_loader, fp16=fp16),
            train_sampler,
            val_sampler)
def main():
    """One-shot evaluation/visualization pass: restore a CNF model, save a
    checkpoint, dump a grid of real test images, and generate samples at a
    range of sampling temperatures.

    Reads module-level globals set at import time (``args``, ``write_log``).
    """
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))

        logger.info(args)

        # Persist the full CLI configuration alongside the logs for reproducibility.
        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(), rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log: logger.info("Distributed: success (%d/%d)" % (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    else:
        device = torch.cuda.current_device()  # NOTE(review): int index here vs torch.device above; both are accepted by .to()

    # import pdb; pdb.set_trace()
    # Cast to float32 and move to the chosen device.
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns).cuda()
    if args.distributed: model = dist_utils.DDP(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    # Fresh runs write CSV headers once; resumed runs keep appending.
    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log: logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
    if write_log: logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9,
                              nesterov=False)

    # restore parameters
    # import pdb; pdb.set_trace()
    if args.resume is not None:
        # import pdb; pdb.set_trace()
        print('resume from checkpoint')
        # Map checkpoint tensors straight onto this rank's GPU.
        checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    # NOTE(review): fixed_z is only bound when write_log is true, but it is
    # read below whenever args.local_rank == 0 — NameError risk if the two
    # conditions ever disagree; confirm against the launch configuration.
    if write_log: fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
        '''
    elif args.resume and args.validate:
        chkdir = os.path.dirname(args.resume)
        wall_clock = 0
        itr = 0
        best_loss = 0.0
        begin_epoch = 0
        '''
    else:
        # Resume bookkeeping from the CSV logs of the previous run.
        chkdir = os.path.dirname(args.resume)
        filename = os.path.join(chkdir, 'test.csv')
        print(filename)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        # import pdb; pdb.set_trace()
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        # All-reduce acts as a barrier so every rank starts together.
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    # Single-iteration range: this entry point evaluates/visualizes once.
    for epoch in range(begin_epoch, begin_epoch + 1):
        # compute test loss
        print('Evaluating')
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            # import pdb; pdb.set_trace()
            if hasattr(model, 'module'):
                _state = model.module.state_dict()  # unwrap the DDP wrapper
            else:
                _state = model.state_dict()
            torch.save({
                "args": args,
                "state_dict": _state,  # model.module.state_dict() if torch.cuda.is_available() else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

        # save real and generate with different temperatures
        fig_num = 64
        if True:  # args.save_real:
            # NOTE(review): walks the loader past batch 100; `real` is assigned
            # but never used, and `x` deliberately leaks out of the loop
            # carrying the last batch seen — confirm this batch choice is intended.
            for i, (x, y) in enumerate(test_loader):
                if i < 100:
                    pass
                elif i == 100:
                    real = x.size(0)
                else:
                    break
            if x.shape[0] > fig_num:
                x = x[:fig_num, ...]
            # import pdb; pdb.set_trace()
            fig_filename = os.path.join(chkdir, "real.jpg")
            save_image(x.float() / 255.0, fig_filename, nrow=8)

        if True:  # args.generate:
            print('\nGenerating images... ')
            fixed_z = cvt(torch.randn(fig_num, *data_shape))
            nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
            # Scale the latent by each temperature t before the reverse pass.
            for t in [ 1.0, 0.99, 0.98, 0.97,0.96,0.95,0.93,0.92,0.90,0.85,0.8,0.75,0.7,0.65,0.6]:
                # visualize samples and density
                fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples = model(t * fixed_z, reverse=True)
                x = unshift(generated_samples[0].view(-1, *data_shape), 8)
                save_image(x, fig_filename, nrow=nb)
def get_dataset(args):
    """Build train/test dataloaders for the dataset named by ``args.data``.

    Args:
        args: parsed CLI namespace; reads data, imagesize, datadir, batch_size,
            test_batch_size, nworkers, and distributed.

    Returns:
        (train_loader, test_loader, data_shape) where data_shape is
        (channels, height, width) of one image.

    Raises:
        ValueError: if ``args.data`` names an unknown dataset.
    """
    trans = lambda im_size: tforms.Compose([tforms.Resize(im_size)])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root=args.datadir, train=True, transform=trans(im_size), download=True)
        test_set = dset.MNIST(root=args.datadir, train=False, transform=trans(im_size), download=True)
    elif args.data == "svhn":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.SVHN(root=args.datadir, split="train", transform=trans(im_size), download=True)
        test_set = dset.SVHN(root=args.datadir, split="test", transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root=args.datadir, train=True, transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
            ]), download=True
        )
        # NOTE(review): test split intentionally has no transform here (CIFAR
        # images are already 32x32); confirm if a non-default imagesize is used.
        test_set = dset.CIFAR10(root=args.datadir, train=False, transform=None, download=True)
    elif args.data == 'celebahq':
        im_dim = 3
        im_size = 256 if args.imagesize is None else args.imagesize
        train_set = CelebAHQ(train=True, root=args.datadir)
        test_set = CelebAHQ(train=False, root=args.datadir)
    elif args.data == 'imagenet64':
        im_dim = 3
        # Imagenet64 is fixed-size; force the arg so downstream code agrees.
        if args.imagesize != 64:
            args.imagesize = 64
        im_size = 64
        train_set = Imagenet64(train=True, root=args.datadir)
        test_set = Imagenet64(train=False, root=args.datadir)
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        train_set = dset.LSUN(
            'data', ['church_outdoor_train'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ])
        )
        test_set = dset.LSUN(
            'data', ['church_outdoor_val'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ])
        )
    else:
        # Previously an unknown name crashed later with UnboundLocalError on
        # im_dim; fail fast with a clear message instead.
        raise ValueError("Unknown dataset: {}".format(args.data))
    data_shape = (im_dim, im_size, im_size)

    def fast_collate(batch):
        """Collate (image, target) pairs into a uint8 NCHW tensor and an
        int64 target vector, avoiding per-image float conversion."""
        imgs = [sample[0] for sample in batch]
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)

        tensor = torch.zeros((len(imgs), im_dim, im_size, im_size), dtype=torch.uint8)
        for i, img in enumerate(imgs):
            nump_array = np.asarray(img, dtype=np.uint8)
            if nump_array.ndim < 3:
                # Grayscale: add a trailing channel axis so rollaxis works.
                nump_array = np.expand_dims(nump_array, axis=-1)
            nump_array = np.rollaxis(nump_array, 2)  # HWC -> CHW
            tensor[i] += torch.from_numpy(nump_array)

        return tensor, targets

    train_sampler = (DistributedSampler(train_set,
                                        num_replicas=env_world_size(), rank=env_rank()) if args.distributed
                     else None)
    test_sampler = (DistributedSampler(test_set,
                                       num_replicas=env_world_size(), rank=env_rank()) if args.distributed
                    else None)

    # Passing sampler=None is equivalent to omitting it, so one construction
    # per loader covers both the distributed and single-process cases; shuffle
    # only when no sampler controls the ordering.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=args.batch_size,
        shuffle=(train_sampler is None), sampler=train_sampler,
        num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=args.test_batch_size,
        shuffle=False, sampler=test_sampler,
        num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate
    )
    return train_loader, test_loader, data_shape
    #                    help='Shutdown instance at the end of training or failure')
    # parser.add_argument('--auto-shutdown-success-delay-mins', default=10, type=int,
    #                    help='how long to wait until shutting down on success')
    # parser.add_argument('--auto-shutdown-failure-delay-mins', default=60, type=int,
    #                    help='how long to wait before shutting down on error')

    return parser


# Let cuDNN auto-tune conv algorithms (fast for fixed input shapes).
cudnn.benchmark = True
args = get_parser().parse_args()
# torch.manual_seed(args.seed)
# Number of discrete pixel levels at the requested bit depth.
nvals = 2 ** args.nbits

# Only want master rank logging
is_master = (not args.distributed) or (dist_utils.env_rank() == 0)
is_rank0 = args.local_rank == 0  # first process on this node
write_log = is_rank0 and is_master  # single writer for logs/checkpoints/figures


def add_noise(x, nbits=8):
    """Map integer pixel values to [0, 1), dequantizing with uniform noise
    when the module-level ``args.add_noise`` flag is set.

    Note: mutates its operand via ``add_``/``div_`` — the caller's tensor when
    nbits == 8, otherwise the fresh tensor produced by the ``//`` rebinding.
    """
    levels = 2 ** nbits
    if nbits < 8:
        # Collapse 8-bit values down to the requested bit depth first.
        x = x // (2 ** (8 - nbits))
    noise = x.new().resize_as_(x).uniform_() if args.add_noise else 0.5
    return x.add_(noise).div_(levels)


def shift(x, nbits=8):
예제 #7
0
def main():
    #os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get deivce
    # device = torch.device("cuda:%d"%torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = [
        'itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time',
        'grad_norm'
    ]
    testcolumns = [
        'wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
        'transport_cost'
    ]

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log:
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
    if write_log:
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] +
                          1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    #x = x.clamp_(min=0, max=1)
                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(reg_state * coeff
                                       for reg_state, coeff in zip(
                                           reg_states, regularization_coeffs)
                                       if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log: steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([
                        1., batch_size,
                        loss.item(),
                        bpd.item(), nfe_opt, grad_norm, *reg_states
                    ]).float()

                    rv = tuple(torch.tensor(0.) for r in reg_states)

                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(
                        metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save(
                {
                    "args":
                    args,
                    "state_dict":
                    model.module.state_dict()
                    if torch.cuda.is_available() else model.state_dict(),
                    "optim_state_dict":
                    optimizer.state_dict(),
                    "fixed_z":
                    fixed_z.cpu()
                }, os.path.join(args.save, "checkpt.pth"))
        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log: logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        sh = x.shape
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) -
                                z).pow(2).mean(dim=-1).mean()
                        meandist = i / (i + 1) * dist + meandist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)

                        tt = i / (i + 1) * tt + count_total_time(model) / (i +
                                                                           1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i +
                                                                          1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }

                        csvlogger.writerow(logdict)

                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt,
                                    r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs",
                                                "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename,
                               nrow=nb)
            if args.validate:
                break