Example #1
def main_train_loop(save_dir, model, args):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_class = len(args.cates)
    #resume checkpoint
    start_epoch = 0
    optimizer = initilize_optimizer(model, args)
    if args.resume_checkpoint is None and os.path.exists(
            os.path.join(save_dir, 'checkpoint-latest.pt')):
        args.resume_checkpoint = os.path.join(
            save_dir, 'checkpoint-latest.pt')  # use the latest checkpoint
    if args.resume_checkpoint is not None:
        if args.resume_optimizer:
            model, optimizer, start_epoch = resume(
                args.resume_checkpoint,
                model,
                optimizer,
                strict=(not args.resume_non_strict))
        else:
            model, _, start_epoch = resume(args.resume_checkpoint,
                                           model,
                                           optimizer=None,
                                           strict=(not args.resume_non_strict))
        print('Resumed from: ' + args.resume_checkpoint)

    #initialize datasets and loaders
    tr_dataset, te_dataset = get_datasets(args)

    train_sampler = None  # for non-distributed training

    train_loader = torch.utils.data.DataLoader(dataset=tr_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=0,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True,
                                               # seed numpy in each worker process
                                               worker_init_fn=lambda _: np.random.seed(args.seed))
    test_loader = torch.utils.data.DataLoader(dataset=te_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              pin_memory=True,
                                              drop_last=False,
                                              # seed numpy in each worker process
                                              worker_init_fn=lambda _: np.random.seed(args.seed))

    #initialize the learning rate scheduler
    if args.scheduler == 'exponential':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
    elif args.scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.epochs // 2,
                                              gamma=0.1)
    elif args.scheduler == 'linear':

        def lambda_rule(ep):
            lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(
                0.5 * args.epochs)
            return lr_l

        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lambda_rule)
    else:
        assert 0, "args.schedulers should be either 'exponential' or 'linear'"

    #training starts from here
    tot_nelbo = []
    tot_kl_loss = []
    tot_x_reconst = []

    best_eval_metric = float('+inf')

    for epoch in range(start_epoch, args.epochs):
        # adjust the learning rate
        if (epoch + 1) % args.exp_decay_freq == 0:
            scheduler.step(epoch=epoch)
        #train for one epoch
        model.train()
        for bidx, data in enumerate(train_loader):
            idx_batch, tr_batch, te_batch = data['idx'], data[
                'train_points'], data['test_points']
            obj_type = data['cate_idx']
            y_one_hot = obj_type.new(
                np.eye(n_class)[obj_type]).to(device).float()
            step = bidx + len(train_loader) * epoch

            if args.random_rotate:
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)

            inputs = tr_batch.to(device)
            y_one_hot = y_one_hot.to(device)
            optimizer.zero_grad()
            inputs_dict = {'x': inputs, 'y_class': y_one_hot}
            ret = model(inputs_dict)
            loss, nelbo, kl_loss, x_reconst, cl_loss = ret['loss'], ret[
                'nelbo'], ret['kl_loss'], ret['x_reconst'], ret['cl_loss']
            loss.backward()
            optimizer.step()

            cur_loss = loss.cpu().item()
            cur_nelbo = nelbo.cpu().item()
            cur_kl_loss = kl_loss.cpu().item()
            cur_x_reconst = x_reconst.cpu().item()
            cur_cl_loss = cl_loss.cpu().item()
            tot_nelbo.append(cur_nelbo)
            tot_kl_loss.append(cur_kl_loss)
            tot_x_reconst.append(cur_x_reconst)
            if step % args.log_freq == 0:
                print(
                    "Epoch {0:6d} Step {1:12d} Loss {2:12.6f} Nelbo {3:12.6f} KL Loss {4:12.6f} Reconst Loss {5:12.6f} CL_Loss{6:12.6f}"
                    .format(epoch, step, cur_loss, cur_nelbo, cur_kl_loss,
                            cur_x_reconst, cur_cl_loss))

        #save checkpoint
        if (epoch + 1) % args.save_freq == 0:
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-latest.pt'))

            eval_metric = evaluate_model(model, te_dataset, args)
            train_metric = evaluate_model(model, tr_dataset, args)

            print('Checkpoint: Dev Reconst Loss:{0}, Train Reconst Loss:{1}'.
                  format(eval_metric, train_metric))
            if eval_metric < best_eval_metric:
                best_eval_metric = eval_metric
                save(model, optimizer, epoch + 1,
                     os.path.join(save_dir, 'checkpoint-best.pt'))
                print('new best model found!')

    save(model, optimizer, args.epochs,
         os.path.join(save_dir, 'checkpoint-latest.pt'))
    #save final visualization of 5 reconstructed samples
    model.eval()
    with torch.no_grad():
        samples_A = model.reconstruct_input(inputs)  #sample_point(5)
        results = []
        for idx in range(5):
            res = visualize_point_clouds(
                samples_A[idx],
                tr_batch[idx],
                idx,
                pert_order=train_loader.dataset.display_axis_order)
            results.append(res)
        res = np.concatenate(results, axis=1)
        imsave(os.path.join(save_dir, 'images', '_epoch%d.png' % (epoch)),
               res.transpose((1, 2, 0)))

    #load the best model and compute eval metric:
    best_model_path = os.path.join(save_dir, 'checkpoint-best.pt')
    ckpt = torch.load(best_model_path)
    model.load_state_dict(ckpt['model'], strict=True)
    eval_metric = evaluate_model(model, te_dataset, args)
    train_metric = evaluate_model(model, tr_dataset, args)
    print(
        'Best model at epoch:{2} Dev Reconst Loss:{0}, Train Reconst Loss:{1}'.
        format(eval_metric, train_metric, ckpt['epoch']))
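
The helpers save() and resume() and the checkpoint layout are not shown in the examples. The sketch below is only an assumption about those utilities, inferred from how they are called above (checkpoints appear to be plain dicts with 'model', 'optimizer' and 'epoch' keys, judging by the final torch.load block); it is not the project's actual implementation. Note that Example #4 calls a save() with a different, extended signature (scheduler, best loss, log dir), so this sketch only mirrors the first three examples.

import torch

def save(model, optimizer, epoch, path):
    # Store model/optimizer state plus the epoch counter in a single file.
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'epoch': epoch,
    }, path)

def resume(path, model, optimizer=None, strict=True):
    # Restore weights (and optionally the optimizer state) from a checkpoint
    # and return the epoch to continue from.
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'], strict=strict)
    if optimizer is not None and ckpt.get('optimizer') is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, ckpt['epoch']
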
Example #2
def main_worker(gpu, save_dir, ngpus_per_node, args):
    # basic setup
    cudnn.benchmark = True
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.log_name is not None:
        log_dir = "runs/%s" % args.log_name
    else:
        log_dir = "runs/time-%d" % time.time()

    if not args.distributed or (args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(logdir=log_dir)
    else:
        writer = None

    if not args.use_latent_flow:  # auto-encoder only
        args.prior_weight = 0
        args.entropy_weight = 0

    # multi-GPU setup
    model = PointFlow(args)
    if args.distributed:  # Multiple processes, single GPU per process
        if args.gpu is not None:

            def _transform_(m):
                return nn.parallel.DistributedDataParallel(
                    m,
                    device_ids=[args.gpu],
                    output_device=args.gpu,
                    check_reduction=True)

            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model.multi_gpu_wrapper(_transform_)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = 0
        else:
            assert 0, "DistributedDataParallel constructor should always set the single device scope"
    elif args.gpu is not None:  # Single process, single GPU per process
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:  # Single process, multiple GPUs per process

        def _transform_(m):
            return nn.DataParallel(m)

        model = model.cuda()
        model.multi_gpu_wrapper(_transform_)

    # resume checkpoints
    start_epoch = 0
    optimizer = model.make_optimizer(args)
    if args.resume_checkpoint is None and os.path.exists(
            os.path.join(save_dir, 'checkpoint-latest.pt')):
        args.resume_checkpoint = os.path.join(
            save_dir, 'checkpoint-latest.pt')  # use the latest checkpoint
    if args.resume_checkpoint is not None:
        if args.resume_optimizer:
            model, optimizer, start_epoch = resume(
                args.resume_checkpoint,
                model,
                optimizer,
                strict=(not args.resume_non_strict))
        else:
            model, _, start_epoch = resume(args.resume_checkpoint,
                                           model,
                                           optimizer=None,
                                           strict=(not args.resume_non_strict))
        print('Resumed from: ' + args.resume_checkpoint)

    # initialize datasets and loaders
    tr_dataset, te_dataset = get_datasets(args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            tr_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(dataset=tr_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=0,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True,
                                               worker_init_fn=init_np_seed)
    test_loader = torch.utils.data.DataLoader(dataset=te_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              pin_memory=True,
                                              drop_last=True,
                                              worker_init_fn=init_np_seed)

    # save dataset statistics
    if not args.distributed or (args.rank % ngpus_per_node == 0):
        np.save(os.path.join(save_dir, "train_set_mean.npy"),
                tr_dataset.all_points_mean)
        np.save(os.path.join(save_dir, "train_set_std.npy"),
                tr_dataset.all_points_std)
        np.save(os.path.join(save_dir, "train_set_idx.npy"),
                np.array(tr_dataset.shuffle_idx))
        np.save(os.path.join(save_dir, "val_set_mean.npy"),
                te_dataset.all_points_mean)
        np.save(os.path.join(save_dir, "val_set_std.npy"),
                te_dataset.all_points_std)
        np.save(os.path.join(save_dir, "val_set_idx.npy"),
                np.array(te_dataset.shuffle_idx))

    # load classification dataset if needed
    if args.eval_classification:
        from datasets import get_clf_datasets

        def _make_data_loader_(dataset):
            return torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0,
                                               pin_memory=True,
                                               drop_last=True,
                                               worker_init_fn=init_np_seed)

        clf_datasets = get_clf_datasets(args)
        clf_loaders = {
            k: [_make_data_loader_(ds) for ds in ds_lst]
            for k, ds_lst in clf_datasets.items()
        }
    else:
        clf_loaders = None

    # initialize the learning rate scheduler
    if args.scheduler == 'exponential':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
    elif args.scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.epochs // 2,
                                              gamma=0.1)
    elif args.scheduler == 'linear':

        def lambda_rule(ep):
            lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(
                0.5 * args.epochs)
            return lr_l

        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lambda_rule)
    else:
        assert 0, "args.schedulers should be either 'exponential' or 'linear'"

    # main training loop
    start_time = time.time()
    entropy_avg_meter = AverageValueMeter()
    latent_nats_avg_meter = AverageValueMeter()
    point_nats_avg_meter = AverageValueMeter()
    if args.distributed:
        print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size()))

    print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # adjust the learning rate
        if (epoch + 1) % args.exp_decay_freq == 0:
            scheduler.step(epoch=epoch)
            if writer is not None:
                writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], epoch)

        # train for one epoch
        for bidx, data in enumerate(train_loader):
            # if bidx > 100:
            #     break
            idx_batch, tr_batch, te_batch = data['idx'], data[
                'train_points'], data['test_points']
            step = bidx + len(train_loader) * epoch
            model.train()
            if args.random_rotate:
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)
            inputs = tr_batch.cuda(args.gpu, non_blocking=True)
            out = model(inputs, optimizer, step, writer)
            entropy, prior_nats, recon_nats = out['entropy'], out[
                'prior_nats'], out['recon_nats']
            entropy_avg_meter.update(entropy)
            point_nats_avg_meter.update(recon_nats)
            latent_nats_avg_meter.update(prior_nats)
            if step % args.log_freq == 0:
                duration = time.time() - start_time
                start_time = time.time()
                print(
                    "[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f"
                    % (args.rank, epoch, bidx, len(train_loader), duration,
                       entropy_avg_meter.avg, latent_nats_avg_meter.avg,
                       point_nats_avg_meter.avg))

        # evaluate on the validation set
        if not args.no_validation and (epoch + 1) % args.val_freq == 0:
            from utils import validate
            validate(test_loader,
                     model,
                     epoch,
                     writer,
                     save_dir,
                     args,
                     clf_loaders=clf_loaders)

        # save visualizations
        if (epoch + 1) % args.viz_freq == 0:
            # reconstructions
            model.eval()
            samples = model.reconstruct(inputs)
            results = []
            for idx in range(min(10, inputs.size(0))):
                res = visualize_point_clouds(
                    samples[idx],
                    inputs[idx],
                    idx,
                    pert_order=train_loader.dataset.display_axis_order)
                results.append(res)
            res = np.concatenate(results, axis=1)
            imwrite(
                os.path.join(
                    save_dir, 'images',
                    'tr_vis_conditioned_epoch%d-gpu%s.png' %
                    (epoch, args.gpu)), res.transpose((1, 2, 0)))
            if writer is not None:
                writer.add_image('tr_vis/conditioned', torch.as_tensor(res),
                                 epoch)

            # samples
            if args.use_latent_flow:
                num_samples = min(10, inputs.size(0))
                num_points = inputs.size(1)
                _, samples = model.sample(num_samples, num_points)
                results = []
                for idx in range(num_samples):
                    res = visualize_point_clouds(
                        samples[idx],
                        inputs[idx],
                        idx,
                        pert_order=train_loader.dataset.display_axis_order)
                    results.append(res)
                res = np.concatenate(results, axis=1)
                imwrite(
                    os.path.join(
                        save_dir, 'images',
                        'tr_vis_sampled_epoch%d-gpu%s.png' %
                        (epoch, args.gpu)), res.transpose((1, 2, 0)))
                if writer is not None:
                    writer.add_image('tr_vis/sampled', torch.as_tensor(res),
                                     epoch)

        # save checkpoints
        if not args.distributed or (args.rank % ngpus_per_node == 0):
            if (epoch + 1) % args.save_freq == 0:
                save(model, optimizer, epoch + 1,
                     os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
                save(model, optimizer, epoch + 1,
                     os.path.join(save_dir, 'checkpoint-latest.pt'))
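
Examples #2 and #4 pass init_np_seed as worker_init_fn, but the function itself is not included in the listings. A minimal sketch, assuming it only re-seeds NumPy in each DataLoader worker process:

import numpy as np
import torch

def init_np_seed(worker_id):
    # Each DataLoader worker inherits the parent's NumPy RNG state; derive a
    # distinct seed per worker from torch's base seed so augmentations differ.
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)
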
Example #3
def main():
    torch.backends.cudnn.benchmark = True

    # initialize hyper-parameters
    args = dictobj()
    args.gpu = torch.device('cuda:%d' % (6))
    timestamp = '%d-%d-%d-%d-%d-%d-%d-%d-%d' % time.localtime(time.time())
    args.log_name = '%s-pointflow' % timestamp
    writer = SummaryWriter(comment=args.log_name)

    args.use_latent_flow, args.prior_w, args.entropy_w, args.recon_w = True, 1., 1., 1.
    args.fin, args.fz = 3, 128
    args.use_deterministic_encoder = True
    args.distributed = False
    args.optimizer = optim.Adam
    args.batch_size = 16
    args.lr, args.beta1, args.beta2, args.weight_decay = 1e-3, 0.9, 0.999, 1e-4
    args.T, args.train_T, args.atol, args.rtol = 1., False, 1e-5, 1e-5
    args.layer_type = diffop.CoScaleLinear
    args.solver = 'dopri5'
    args.use_adjoint, args.bn = True, False
    args.dims, args.num_blocks = (512, 512), 1  # originally (512 * 3)
    args.latent_dims, args.latent_num_blocks = (256, 256), 1

    args.resume, args.resume_path = False, None
    args.end_epoch = 2000
    args.scheduler, args.scheduler_step_size = optim.lr_scheduler.StepLR, 20
    args.random_rotation = True
    args.save_freq = 10

    args.dataset_type = 'shapenet15k'
    args.cates = ['airplane']  # 'all' for all categories training
    args.tr_max_sample_points, args.te_max_sample_points = 2048, 2048
    args.dataset_scale = 1.0
    args.normalize_per_shape = False
    args.normalize_std_per_axis = False
    args.num_workers = 4
    args.data_dir = "/data/ShapeNetCore.v2.PC15k"

    torch.cuda.set_device(args.gpu)
    model = PointFlow(**args).cuda(args.gpu)

    # load milestone
    epoch = 0
    optimizer = model.get_optimizer(**args)
    if args.resume:
        model, optimizer, epoch = resume(args.resume_path,
                                         model,
                                         optimizer,
                                         strict=True)
        print("Loaded model from %s" % args.resume_path)

    # load data
    train_dataset, test_dataset = get_datasets(args)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               sampler=None,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              sampler=None,
                                              drop_last=False)

    if args.scheduler == optim.lr_scheduler.StepLR:
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=args.scheduler_step_size, gamma=0.65)
    else:
        raise NotImplementedError("Only StepLR supported")

    ent_rec, latent_rec, recon_rec = Averager(), Averager(), Averager()
    for e in trange(epoch, args.end_epoch):
        # record lr
        if writer is not None:
            writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], e)

        # feed a batch, train
        for idx, data in enumerate(tqdm(train_loader)):
            idx_batch, tr_batch, te_batch = data['idx'], data[
                'train_points'], data['test_points']
            model.train()
            if args.random_rotation:
                # raise NotImplementedError('Random Rotation Augmentation not implemented yet')
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)
            inputs = tr_batch.cuda(args.gpu, non_blocking=True)
            step = idx + len(train_loader) * e  # batch step
            out = model(inputs, optimizer, step, writer, sample_gpu=args.gpu)
            entropy, prior_nats, recon_nats = out['entropy'], out[
                'prior_nats'], out['recon_nats']
            ent_rec.update(entropy)
            recon_rec.update(recon_nats)
            latent_rec.update(prior_nats)

        # update lr
        scheduler.step(epoch=e)

        # save milestones
        if e % args.save_freq == 0 and e != 0:
            save(model, optimizer, e, path='milestone-%d.save' % e)
            save(model, optimizer, e,
                 path='milestone-latest.save')  # save as latest model
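
Example #3 stores all hyper-parameters in a dictobj and later unpacks it with **args, but that container is not defined in the snippet. A minimal sketch, under the assumption that it is simply a dict with attribute access:

class dictobj(dict):
    # A dict whose keys can also be read and written as attributes, so that
    # args.batch_size = 16 works while **args still unpacks like a normal dict.
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value
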
Example #4
def main_worker(gpu, save_dir, ngpus_per_node, init_data, args):
    # basic setup
    cudnn.benchmark = True
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # resume training!!!
    #################################
    if args.resume_checkpoint is None and os.path.exists(os.path.join(save_dir, 'checkpoint-latest.pt')):
        args.resume_checkpoint = os.path.join(save_dir, 'checkpoint-latest.pt')  # use the latest checkpoint
        print('Checkpoint is set to the latest one.')
    #################################

    # multi-GPU setup
    model = SoftPointFlow(args)
    if args.distributed:  # Multiple processes, single GPU per process
        if args.gpu is not None:
            def _transform_(m):
                return nn.parallel.DistributedDataParallel(
                    m, device_ids=[args.gpu], output_device=args.gpu, check_reduction=True)

            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model.multi_gpu_wrapper(_transform_)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = 0
        else:
            assert 0, "DistributedDataParallel constructor should always set the single device scope"
    else:  # Single process, multiple GPUs per process
        def _transform_(m):
            return nn.DataParallel(m)
        model = model.cuda()
        model.multi_gpu_wrapper(_transform_)

    start_epoch = 1
    valid_loss_best = 987654321
    optimizer = model.make_optimizer(args)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    if args.resume_checkpoint is not None:
        model, optimizer, scheduler, start_epoch, valid_loss_best, log_dir = resume(
            args.resume_checkpoint, model, optimizer, scheduler)
        model.set_initialized(True)
        print('Resumed from: ' + args.resume_checkpoint)

    else:
        log_dir = save_dir + "/runs/" + str(time.strftime('%Y-%m-%d_%H:%M:%S'))
        with torch.no_grad():
            inputs, inputs_noisy, std_in = init_data
            inputs = inputs.to(args.gpu, non_blocking=True)
            inputs_noisy = inputs_noisy.to(args.gpu, non_blocking=True)
            std_in = std_in.to(args.gpu, non_blocking=True)
            _ = model(inputs, inputs_noisy, std_in, optimizer,  None, None, init=True)
        del inputs, inputs_noisy, std_in
        print('Actnorm is initialized')

    if not args.distributed or (args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(logdir=log_dir)
    else:
        writer = None

    # initialize datasets and loaders
    tr_dataset = get_trainset(args)
    te_dataset = get_testset(args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(tr_dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(te_dataset)
    else:
        train_sampler = None
        test_sampler = None
        
    train_loader = torch.utils.data.DataLoader(
        dataset=tr_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=0, pin_memory=True, sampler=train_sampler, drop_last=True,
        worker_init_fn=init_np_seed)

    test_loader = torch.utils.data.DataLoader(
        dataset=te_dataset, batch_size=args.batch_size, shuffle=(test_sampler is None),
        num_workers=0, pin_memory=True, sampler=test_sampler, drop_last=True,
        worker_init_fn=init_np_seed)

    # save dataset statistics
    if not args.distributed or (args.rank % ngpus_per_node == 0):
        np.save(os.path.join(save_dir, "train_set_mean.npy"), tr_dataset.all_points_mean)
        np.save(os.path.join(save_dir, "train_set_std.npy"), tr_dataset.all_points_std)
        np.save(os.path.join(save_dir, "train_set_idx.npy"), np.array(tr_dataset.shuffle_idx))
    
    # main training loop
    if args.distributed:
        print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size()))

    seen_inputs = next(iter(train_loader))['train_points'].cuda(args.gpu, non_blocking=True)
    unseen_inputs = next(iter(test_loader))['test_points'].cuda(args.gpu, non_blocking=True)
    del test_loader

    print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
    for epoch in range(start_epoch, args.epochs+1):
        start_time = time.time()
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if writer is not None:
            writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], epoch)

        model.train()
        # train for one epoch
        
        for bidx, data in enumerate(train_loader):
            step = bidx + len(train_loader) * (epoch - 1)
            tr_batch = data['train_points']
            if args.random_rotate:
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)

            inputs = tr_batch.cuda(args.gpu, non_blocking=True)
            B, N, D = inputs.shape
            std = (args.std_max - args.std_min) * torch.rand_like(inputs[:,:,0]).view(B,N,1) + args.std_min

            eps = torch.randn_like(inputs) * std
            std_in = std / args.std_max * args.std_scale
            inputs_noisy = inputs + eps
            out = model(inputs, inputs_noisy, std_in, optimizer, step, writer)
            entropy, prior_nats, recon_nats, loss = out['entropy'], out['prior_nats'], out['recon_nats'], out['loss']
            if step % args.log_freq == 0:
                duration = time.time() - start_time
                start_time = time.time()
                if writer is not None:
                    writer.add_scalar('train/avg_time', duration, step)
                print("[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f loss %2.5f"
                      % (args.rank, epoch, bidx, len(train_loader), duration, entropy,
                         prior_nats, recon_nats, loss))
            del inputs, inputs_noisy, std_in, out, eps
            gc.collect()

        if epoch < args.stop_scheduler:
            scheduler.step()

        if epoch % args.valid_freq == 0:
            with torch.no_grad():
                model.eval()
                valid_loss = 0.0
                valid_entropy = 0.0
                valid_prior = 0.0
                valid_prior_nats = 0.0
                valid_recon = 0.0
                valid_recon_nats = 0.0
                # validation reuses train_loader and its held-out 'test_points'
                # split (the separate test_loader was deleted above)
                for bidx, data in enumerate(train_loader):
                    step = bidx + len(train_loader) * epoch
                    tr_batch = data['test_points']
                    if args.random_rotate:
                        tr_batch, _, _ = apply_random_rotation(
                            tr_batch, rot_axis=train_loader.dataset.gravity_axis)

                    inputs = tr_batch.cuda(args.gpu, non_blocking=True)
                    B, N, D = inputs.shape
                    std = (args.std_max - args.std_min) * torch.rand_like(inputs[:,:,0]).view(B,N,1) + args.std_min

                    eps = torch.randn_like(inputs) * std
                    std_in = std / args.std_max * args.std_scale
                    inputs_noisy = inputs + eps
                    out = model(inputs, inputs_noisy, std_in, optimizer, step, writer, valid=True)
                    valid_loss += out['loss'] / len(train_loader)
                    valid_entropy += out['entropy'] / len(train_loader)
                    valid_prior += out['prior'] / len(train_loader)
                    valid_prior_nats += out['prior_nats'] / len(train_loader)
                    valid_recon += out['recon'] / len(train_loader)
                    valid_recon_nats += out['recon_nats'] / len(train_loader)
                    del inputs, inputs_noisy, std_in, out, eps
                    gc.collect()

                if writer is not None:
                    writer.add_scalar('valid/entropy', valid_entropy, epoch)
                    writer.add_scalar('valid/prior', valid_prior, epoch)
                    writer.add_scalar('valid/prior(nats)', valid_prior_nats, epoch)
                    writer.add_scalar('valid/recon', valid_recon, epoch)
                    writer.add_scalar('valid/recon(nats)', valid_recon_nats, epoch)
                    writer.add_scalar('valid/loss', valid_loss, epoch)
                
                duration = time.time() - start_time
                start_time = time.time()
                print("[Valid] Epoch %d Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f loss %2.5f loss_best %2.5f"
                    % (epoch, duration, valid_entropy, valid_prior_nats, valid_recon_nats, valid_loss, valid_loss_best))
                if valid_loss < valid_loss_best:
                    valid_loss_best = valid_loss
                    if not args.distributed or (args.rank % ngpus_per_node == 0):
                        save(model, optimizer, epoch + 1, scheduler, valid_loss_best, log_dir,
                            os.path.join(save_dir, 'checkpoint-best.pt'))
                        print('best model saved!')

        if epoch % args.save_freq == 0 and (not args.distributed or (args.rank % ngpus_per_node == 0)):
            save(model, optimizer, epoch + 1, scheduler, valid_loss_best, log_dir,
                os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
            save(model, optimizer, epoch + 1, scheduler, valid_loss_best, log_dir,
                os.path.join(save_dir, 'checkpoint-latest.pt'))
            print('model saved!')

        # save visualizations
        if epoch % args.viz_freq == 0:
            with torch.no_grad():
                # reconstructions
                model.eval()
                samples = model.reconstruct(unseen_inputs)
                results = []
                for idx in range(min(16, unseen_inputs.size(0))):
                    res = visualize_point_clouds(samples[idx], unseen_inputs[idx], idx,
                                                pert_order=train_loader.dataset.display_axis_order)

                    results.append(res)
                res = np.concatenate(results, axis=1)
                imageio.imwrite(os.path.join(save_dir, 'images', 'SPF_epoch%d-gpu%s_recon_unseen.png' % (epoch, args.gpu)),
                                res.transpose(1, 2, 0))
                if writer is not None:
                    writer.add_image('tr_vis/conditioned', torch.as_tensor(res), epoch)

                samples = model.reconstruct(seen_inputs)
                results = []
                for idx in range(min(16, seen_inputs.size(0))):
                    res = visualize_point_clouds(samples[idx], seen_inputs[idx], idx,
                                                pert_order=train_loader.dataset.display_axis_order)

                    results.append(res)
                res = np.concatenate(results, axis=1)
                imageio.imwrite(os.path.join(save_dir, 'images', 'SPF_epoch%d-gpu%s_recon_seen.png' % (epoch, args.gpu)),
                                res.transpose(1, 2, 0))
                if writer is not None:
                    writer.add_image('tr_vis/conditioned', torch.as_tensor(res), epoch)

                num_samples = min(16, unseen_inputs.size(0))
                num_points = unseen_inputs.size(1)
                _, samples = model.sample(num_samples, num_points)
                results = []
                for idx in range(num_samples):
                    res = visualize_point_clouds(samples[idx], unseen_inputs[idx], idx,
                                                pert_order=train_loader.dataset.display_axis_order)
                    results.append(res)
                res = np.concatenate(results, axis=1)
                imageio.imwrite(os.path.join(save_dir, 'images', 'SPF_epoch%d-gpu%s_sample.png' % (epoch, args.gpu)),
                                res.transpose((1, 2, 0)))
                if writer is not None:
                    writer.add_image('tr_vis/sampled', torch.as_tensor(res), epoch)
                
                print('image saved!')
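
All four examples augment the training batch with apply_random_rotation(tr_batch, rot_axis=...), which is not part of the listings. The sketch below is an assumption about that helper: it rotates every cloud in a (B, N, 3) batch by a random angle about the given axis and returns the rotated batch together with the rotation matrices and angles, matching the three return values unpacked above.

import numpy as np
import torch

def apply_random_rotation(pc, rot_axis=1):
    # pc: (B, N, 3) batch of point clouds; rotate each cloud by its own random angle.
    B = pc.size(0)
    theta = torch.rand(B) * 2 * np.pi
    cos, sin = torch.cos(theta), torch.sin(theta)
    zeros, ones = torch.zeros(B), torch.ones(B)
    if rot_axis == 0:      # rotate about x
        rot = torch.stack([ones, zeros, zeros,
                           zeros, cos, -sin,
                           zeros, sin, cos], dim=1)
    elif rot_axis == 1:    # rotate about y
        rot = torch.stack([cos, zeros, sin,
                           zeros, ones, zeros,
                           -sin, zeros, cos], dim=1)
    else:                  # rotate about z
        rot = torch.stack([cos, -sin, zeros,
                           sin, cos, zeros,
                           zeros, zeros, ones], dim=1)
    rot = rot.view(B, 3, 3).to(pc)
    pc_rot = torch.bmm(pc, rot.transpose(1, 2))  # apply R to row-vector points
    return pc_rot, rot, theta
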