示例#1
0
def main_train_loop(save_dir, model, args):
    """Train a class-conditional point-cloud model and report its best checkpoint.

    Steps performed:
      1. Resume from ``args.resume_checkpoint`` or, when that is None and
         ``<save_dir>/checkpoint-latest.pt`` exists, from that file.
      2. Build the training DataLoader and the learning-rate scheduler.
      3. Train from the resumed epoch up to ``args.epochs``. Every
         ``args.save_freq`` epochs: save checkpoints, evaluate on the test
         and train datasets with ``evaluate_model``, and keep the lowest test
         metric seen in this call in ``checkpoint-best.pt``.
      4. Save ``checkpoint-latest.pt``. Write a reconstruction image of the
         last training batch to ``<save_dir>/images``. If a best checkpoint
         exists, reload it and print its metrics.

    Args:
        save_dir: Directory holding checkpoints. Images are written to its
            ``images`` subdirectory (assumed to exist already — TODO confirm).
        model: Model called with ``{'x': points, 'y_class': one_hot}``. Its
            output dict contains ``loss``, ``nelbo``, ``kl_loss``,
            ``x_reconst`` and ``cl_loss``. It must also provide
            ``reconstruct_input``.
        args: Parsed training options (``epochs``, ``batch_size``, ``seed``,
            ``scheduler``, ``save_freq``, ``log_freq``, ...). Note that
            ``args.resume_checkpoint`` may be overwritten in place.

    Raises:
        ValueError: If ``args.scheduler`` is not 'exponential', 'step' or
            'linear'.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_class = len(args.cates)

    # resume checkpoint
    start_epoch = 0
    optimizer = initilize_optimizer(model, args)
    latest_ckpt_path = os.path.join(save_dir, 'checkpoint-latest.pt')
    if args.resume_checkpoint is None and os.path.exists(latest_ckpt_path):
        args.resume_checkpoint = latest_ckpt_path  # use the latest checkpoint
    if args.resume_checkpoint is not None:
        strict = not args.resume_non_strict
        if args.resume_optimizer:
            model, optimizer, start_epoch = resume(args.resume_checkpoint,
                                                   model,
                                                   optimizer,
                                                   strict=strict)
        else:
            model, _, start_epoch = resume(args.resume_checkpoint,
                                           model,
                                           optimizer=None,
                                           strict=strict)
        print('Resumed from: ' + args.resume_checkpoint)

    # initialize datasets and the training loader
    tr_dataset, te_dataset = get_datasets(args)

    # Seed NumPy's global RNG explicitly. The previous code passed
    # ``worker_init_fn=np.random.seed(args.seed)``. That *calls* the seeding
    # function while building the loader and hands ``None`` to the DataLoader.
    # With ``num_workers=0`` no worker init function would run anyway.
    np.random.seed(args.seed)
    train_sampler = None  # for non distributed training
    train_loader = torch.utils.data.DataLoader(dataset=tr_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=0,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    # initialize the learning rate scheduler
    if args.scheduler == 'exponential':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
    elif args.scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.epochs // 2,
                                              gamma=0.1)
    elif args.scheduler == 'linear':

        def lambda_rule(ep):
            # Constant for the first half of training, then linear decay to 0.
            lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(
                0.5 * args.epochs)
            return lr_l

        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lambda_rule)
    else:
        # Raise rather than ``assert``: asserts vanish under ``python -O``,
        # which would leave ``scheduler`` unbound further down.
        raise ValueError(
            "args.scheduler should be one of 'exponential', 'step' or "
            "'linear', got %r" % (args.scheduler,))

    # training starts from here
    tot_nelbo = []
    tot_kl_loss = []
    tot_x_reconst = []

    best_eval_metric = float('+inf')
    # Last trained batch (CPU copy and device copy), reused for the final
    # visualization. Stays None if no batch is trained in this call.
    tr_batch = inputs = None

    for epoch in range(start_epoch, args.epochs):
        # adjust the learning rate
        if (epoch + 1) % args.exp_decay_freq == 0:
            scheduler.step(epoch=epoch)
        # train for one epoch
        model.train()
        for bidx, data in enumerate(train_loader):
            tr_batch = data['train_points']
            obj_type = data['cate_idx']
            # One-hot class labels built by indexing rows of an identity matrix.
            y_one_hot = obj_type.new(
                np.eye(n_class)[obj_type]).to(device).float()
            step = bidx + len(train_loader) * epoch

            if args.random_rotate:
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)

            inputs = tr_batch.to(device)
            optimizer.zero_grad()
            ret = model({'x': inputs, 'y_class': y_one_hot})
            loss = ret['loss']
            loss.backward()
            optimizer.step()

            cur_loss = loss.cpu().item()
            cur_nelbo = ret['nelbo'].cpu().item()
            cur_kl_loss = ret['kl_loss'].cpu().item()
            cur_x_reconst = ret['x_reconst'].cpu().item()
            cur_cl_loss = ret['cl_loss'].cpu().item()
            tot_nelbo.append(cur_nelbo)
            tot_kl_loss.append(cur_kl_loss)
            tot_x_reconst.append(cur_x_reconst)
            if step % args.log_freq == 0:
                print(
                    "Epoch {0:6d} Step {1:12d} Loss {2:12.6f} Nelbo {3:12.6f} KL Loss {4:12.6f} Reconst Loss {5:12.6f} CL_Loss{6:12.6f}"
                    .format(epoch, step, cur_loss, cur_nelbo, cur_kl_loss,
                            cur_x_reconst, cur_cl_loss))

        # save checkpoint and evaluate
        if (epoch + 1) % args.save_freq == 0:
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
            save(model, optimizer, epoch + 1, latest_ckpt_path)

            eval_metric = evaluate_model(model, te_dataset, args)
            train_metric = evaluate_model(model, tr_dataset, args)

            print('Checkpoint: Dev Reconst Loss:{0}, Train Reconst Loss:{1}'.
                  format(eval_metric, train_metric))
            if eval_metric < best_eval_metric:
                best_eval_metric = eval_metric
                save(model, optimizer, epoch + 1,
                     os.path.join(save_dir, 'checkpoint-best.pt'))
                print('new best model found!')

    save(model, optimizer, args.epochs, latest_ckpt_path)

    # Save a visualization of up to 5 reconstructions from the last training
    # batch. Skipped when no batch was trained in this call, e.g. when resuming
    # a run that already reached ``args.epochs``. That case previously raised
    # NameError.
    if inputs is not None:
        model.eval()
        with torch.no_grad():
            samples_A = model.reconstruct_input(inputs)
            n_vis = min(5, inputs.size(0))
            results = [
                visualize_point_clouds(
                    samples_A[idx],
                    tr_batch[idx],
                    idx,
                    pert_order=train_loader.dataset.display_axis_order)
                for idx in range(n_vis)
            ]
            res = np.concatenate(results, axis=1)
            # When training ran, the final loop epoch is args.epochs - 1.
            imsave(
                os.path.join(save_dir, 'images',
                             '_epoch%d.png' % (args.epochs - 1)),
                res.transpose((1, 2, 0)))

    # load the best model and compute eval metric
    best_model_path = os.path.join(save_dir, 'checkpoint-best.pt')
    if not os.path.exists(best_model_path):
        # No evaluation epoch has written a best checkpoint yet.
        print('No best checkpoint found at %s; skipping final evaluation.' %
              best_model_path)
        return
    ckpt = torch.load(best_model_path)
    model.load_state_dict(ckpt['model'], strict=True)
    eval_metric = evaluate_model(model, te_dataset, args)
    train_metric = evaluate_model(model, tr_dataset, args)
    print(
        'Best model at epoch:{2} Dev Reconst Loss:{0}, Train Reconst Loss:{1}'.
        format(eval_metric, train_metric, ckpt['epoch']))
示例#2
0
def main_worker(gpu, save_dir, ngpus_per_node, args):
    """Per-process PointFlow training entry point.

    When ``args.distributed`` is set, runs one process per GPU. Otherwise it
    runs single-process, on one GPU or on all GPUs via DataParallel. It
    resumes from a checkpoint when available, trains for ``args.epochs``
    epochs, writes reconstruction/sample images every ``args.viz_freq``
    epochs, and saves checkpoints every ``args.save_freq`` epochs (from
    rank 0 of each node only).

    Args:
        gpu: GPU index used by this process, or None.
        save_dir: Directory for checkpoints. Images go to its ``images``
            subdirectory.
        ngpus_per_node: GPUs per node. Used to compute the global rank and to
            split ``args.batch_size`` across processes.
        args: Parsed options. Mutated in place: ``gpu``, ``rank``,
            ``batch_size``, ``workers``, ``resume_checkpoint``, and
            ``prior_weight``/``entropy_weight`` when the latent flow is
            disabled.

    Raises:
        ValueError: If distributed training is requested without a GPU index,
            or if ``args.scheduler`` is not 'exponential', 'step' or 'linear'.
    """
    # basic setup
    cudnn.benchmark = True
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        # Global rank = node rank * GPUs per node + local GPU index.
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.log_name is not None:
        log_dir = "runs/%s" % args.log_name
    else:
        log_dir = "runs/time-%d" % time.time()

    # Only the first process on each node writes TensorBoard logs.
    if not args.distributed or (args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(logdir=log_dir)
    else:
        writer = None

    if not args.use_latent_flow:  # auto-encoder only
        args.prior_weight = 0
        args.entropy_weight = 0

    # multi-GPU setup
    model = PointFlow(args)
    if args.distributed:  # Multiple processes, single GPU per process
        if args.gpu is None:
            raise ValueError(
                "DistributedDataParallel constructor should always set the single device scope"
            )

        def _transform_(m):
            return nn.parallel.DistributedDataParallel(
                m,
                device_ids=[args.gpu],
                output_device=args.gpu,
                check_reduction=True)

        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        model.multi_gpu_wrapper(_transform_)
        # The global batch size is split evenly across this node's processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = 0
    elif args.gpu is not None:  # Single process, single GPU per process
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:  # Single process, multiple GPUs per process

        def _transform_(m):
            return nn.DataParallel(m)

        model = model.cuda()
        model.multi_gpu_wrapper(_transform_)

    # resume checkpoints
    start_epoch = 0
    optimizer = model.make_optimizer(args)
    latest_ckpt_path = os.path.join(save_dir, 'checkpoint-latest.pt')
    if args.resume_checkpoint is None and os.path.exists(latest_ckpt_path):
        args.resume_checkpoint = latest_ckpt_path  # use the latest checkpoint
    if args.resume_checkpoint is not None:
        strict = not args.resume_non_strict
        if args.resume_optimizer:
            model, optimizer, start_epoch = resume(args.resume_checkpoint,
                                                   model,
                                                   optimizer,
                                                   strict=strict)
        else:
            model, _, start_epoch = resume(args.resume_checkpoint,
                                           model,
                                           optimizer=None,
                                           strict=strict)
        print('Resumed from: ' + args.resume_checkpoint)

    # initialize datasets and loaders
    tr_dataset = MyDataset(args.data_dir, istest=False)
    te_dataset = MyDataset(args.data_dir, istest=True)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            tr_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(dataset=tr_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=0,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True,
                                               worker_init_fn=init_np_seed)
    test_loader = torch.utils.data.DataLoader(dataset=te_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              pin_memory=True,
                                              drop_last=False,
                                              worker_init_fn=init_np_seed)

    # save dataset statistics
    # if not args.distributed or (args.rank % ngpus_per_node == 0):
    #     np.save(os.path.join(save_dir, "train_set_mean.npy"), tr_dataset.all_points_mean)
    #     np.save(os.path.join(save_dir, "train_set_std.npy"), tr_dataset.all_points_std)
    #     np.save(os.path.join(save_dir, "train_set_idx.npy"), np.array(tr_dataset.shuffle_idx))
    #     np.save(os.path.join(save_dir, "val_set_mean.npy"), te_dataset.all_points_mean)
    #     np.save(os.path.join(save_dir, "val_set_std.npy"), te_dataset.all_points_std)
    #     np.save(os.path.join(save_dir, "val_set_idx.npy"), np.array(te_dataset.shuffle_idx))

    # load classification dataset if needed
    if args.eval_classification:
        from datasets import get_clf_datasets

        def _make_data_loader_(dataset):
            return torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0,
                                               pin_memory=True,
                                               drop_last=False,
                                               worker_init_fn=init_np_seed)

        clf_datasets = get_clf_datasets(args)
        clf_loaders = {
            k: [_make_data_loader_(ds) for ds in ds_lst]
            for k, ds_lst in clf_datasets.items()
        }
    else:
        clf_loaders = None

    # initialize the learning rate scheduler
    if args.scheduler == 'exponential':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
    elif args.scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.epochs // 2,
                                              gamma=0.1)
    elif args.scheduler == 'linear':

        def lambda_rule(ep):
            # Constant for the first half of training, then linear decay to 0.
            lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(
                0.5 * args.epochs)
            return lr_l

        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lambda_rule)
    else:
        # Raise rather than ``assert``: asserts vanish under ``python -O``.
        raise ValueError(
            "args.scheduler should be one of 'exponential', 'step' or "
            "'linear', got %r" % (args.scheduler,))

    # main training loop
    start_time = time.time()
    entropy_avg_meter = AverageValueMeter()
    latent_nats_avg_meter = AverageValueMeter()
    point_nats_avg_meter = AverageValueMeter()
    if args.distributed:
        print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size()))

    print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle the distributed shards differently every epoch.
            train_sampler.set_epoch(epoch)

        # adjust the learning rate
        if (epoch + 1) % args.exp_decay_freq == 0:
            scheduler.step(epoch=epoch)
            if writer is not None:
                writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], epoch)

        # train for one epoch
        for bidx, data in enumerate(train_loader):
            tr_batch = data['train_points']
            step = bidx + len(train_loader) * epoch
            model.train()
            inputs = tr_batch.cuda(args.gpu, non_blocking=True)
            # The model performs its own optimizer step and logging.
            out = model(inputs, optimizer, step, writer)
            entropy, prior_nats, recon_nats = out['entropy'], out[
                'prior_nats'], out['recon_nats']
            entropy_avg_meter.update(entropy)
            point_nats_avg_meter.update(recon_nats)
            latent_nats_avg_meter.update(prior_nats)
            if step % args.log_freq == 0:
                duration = time.time() - start_time
                start_time = time.time()
                print(
                    "[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f"
                    % (args.rank, epoch, bidx, len(train_loader), duration,
                       entropy_avg_meter.avg, latent_nats_avg_meter.avg,
                       point_nats_avg_meter.avg))

        # evaluate on the validation set
        # if not args.no_validation and (epoch + 1) % args.val_freq == 0:
        #     from utils import validate
        #     validate(test_loader, model, epoch, writer, save_dir, args, clf_loaders=clf_loaders)

        # save visualizations (uses the last training batch of this epoch)
        if (epoch + 1) % args.viz_freq == 0:
            # reconstructions
            model.eval()
            samples = model.reconstruct(inputs)
            results = []
            for idx in range(min(10, inputs.size(0))):
                res = visualize_point_clouds(samples[idx], inputs[idx], idx)
                results.append(res)
            res = np.concatenate(results, axis=1)
            scipy.misc.imsave(
                os.path.join(
                    save_dir, 'images',
                    'tr_vis_conditioned_epoch%d-gpu%s.png' %
                    (epoch, args.gpu)), res.transpose((1, 2, 0)))
            if writer is not None:
                writer.add_image('tr_vis/conditioned', torch.as_tensor(res),
                                 epoch)

            # samples
            if args.use_latent_flow:
                num_samples = min(10, inputs.size(0))
                num_points = inputs.size(1)
                _, samples = model.sample(num_samples, num_points)
                results = []
                for idx in range(num_samples):
                    res = visualize_point_clouds(samples[idx], inputs[idx],
                                                 idx)
                    results.append(res)
                res = np.concatenate(results, axis=1)
                # Distinct filename: this previously reused the
                # "conditioned" name and overwrote the reconstruction image.
                scipy.misc.imsave(
                    os.path.join(
                        save_dir, 'images',
                        'tr_vis_sampled_epoch%d-gpu%s.png' %
                        (epoch, args.gpu)), res.transpose((1, 2, 0)))
                if writer is not None:
                    writer.add_image('tr_vis/sampled', torch.as_tensor(res),
                                     epoch)

        # save checkpoints (first process per node only)
        if not args.distributed or (args.rank % ngpus_per_node == 0):
            if (epoch + 1) % args.save_freq == 0:
                save(model, optimizer, epoch + 1,
                     os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
                save(model, optimizer, epoch + 1, latest_ckpt_path)
示例#3
0
def main_worker(gpu, save_dir, ngpus_per_node, init_data, args):
    """Per-process SoftPointFlow training entry point.

    Sets up GPU training, optionally distributed. It then either resumes from
    a checkpoint or initializes actnorm layers with one ``init=True`` forward
    pass on ``init_data``. Training runs for epochs ``1..args.epochs``, with
    periodic validation, checkpointing and visualization.

    Args:
        gpu: GPU index used by this process, or None.
        save_dir: Directory for checkpoints, dataset statistics, TensorBoard
            runs (when starting fresh) and images (under ``images/``).
        ngpus_per_node: GPUs per node. Used for the global rank and to split
            ``args.batch_size`` across processes.
        init_data: Tuple ``(inputs, inputs_noisy, std_in)`` fed once through
            the model with ``init=True`` when not resuming.
        args: Parsed options. Mutated in place: ``gpu``, ``rank``,
            ``batch_size``, ``workers``, ``resume_checkpoint``.

    Raises:
        ValueError: If distributed training is requested without a GPU index.
    """
    # basic setup
    cudnn.benchmark = True
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        # Global rank = node rank * GPUs per node + local GPU index.
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # resume training: default to the latest checkpoint in save_dir if present
    latest_ckpt_path = os.path.join(save_dir, 'checkpoint-latest.pt')
    if args.resume_checkpoint is None and os.path.exists(latest_ckpt_path):
        args.resume_checkpoint = latest_ckpt_path  # use the latest checkpoint
        print('Checkpoint is set to the latest one.')

    # multi-GPU setup
    model = SoftPointFlow(args)
    if args.distributed:  # Multiple processes, single GPU per process
        if args.gpu is None:
            raise ValueError(
                "DistributedDataParallel constructor should always set the single device scope")

        def _transform_(m):
            return nn.parallel.DistributedDataParallel(
                m, device_ids=[args.gpu], output_device=args.gpu, check_reduction=True)

        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        model.multi_gpu_wrapper(_transform_)
        # The global batch size is split evenly across this node's processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = 0
    else:  # Single process, multiple GPUs per process
        def _transform_(m):
            return nn.DataParallel(m)
        model = model.cuda()
        model.multi_gpu_wrapper(_transform_)

    start_epoch = 1
    valid_loss_best = 987654321  # sentinel larger than any expected loss
    optimizer = model.make_optimizer(args)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    if args.resume_checkpoint is not None:
        model, optimizer, scheduler, start_epoch, valid_loss_best, log_dir = resume(
            args.resume_checkpoint, model, optimizer, scheduler)
        model.set_initialized(True)
        print('Resumed from: ' + args.resume_checkpoint)
    else:
        log_dir = save_dir + "/runs/" + str(time.strftime('%Y-%m-%d_%H:%M:%S'))
        # Data-dependent actnorm initialization from one forward pass.
        with torch.no_grad():
            inputs, inputs_noisy, std_in = init_data
            inputs = inputs.to(args.gpu, non_blocking=True)
            inputs_noisy = inputs_noisy.to(args.gpu, non_blocking=True)
            std_in = std_in.to(args.gpu, non_blocking=True)
            _ = model(inputs, inputs_noisy, std_in, optimizer, None, None, init=True)
        del inputs, inputs_noisy, std_in
        print('Actnorm is initialized')

    # Only the first process on each node writes TensorBoard logs.
    if not args.distributed or (args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(logdir=log_dir)
    else:
        writer = None

    # initialize datasets and loaders
    tr_dataset = get_trainset(args)
    te_dataset = get_testset(args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(tr_dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(te_dataset)
    else:
        train_sampler = None
        test_sampler = None

    train_loader = torch.utils.data.DataLoader(
        dataset=tr_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=0, pin_memory=True, sampler=train_sampler, drop_last=True,
        worker_init_fn=init_np_seed)

    # Only the first test batch is used (for visualization). drop_last=False
    # keeps a test set smaller than batch_size from producing no batches.
    # That case previously made next(iter(test_loader)) raise StopIteration.
    test_loader = torch.utils.data.DataLoader(
        dataset=te_dataset, batch_size=args.batch_size, shuffle=(test_sampler is None),
        num_workers=0, pin_memory=True, sampler=test_sampler, drop_last=False,
        worker_init_fn=init_np_seed)

    # save dataset statistics
    if not args.distributed or (args.rank % ngpus_per_node == 0):
        np.save(os.path.join(save_dir, "train_set_mean.npy"), tr_dataset.all_points_mean)
        np.save(os.path.join(save_dir, "train_set_std.npy"), tr_dataset.all_points_std)
        np.save(os.path.join(save_dir, "train_set_idx.npy"), np.array(tr_dataset.shuffle_idx))

    # main training loop
    if args.distributed:
        print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size()))

    # Fixed batches reused for every visualization.
    seen_inputs = next(iter(train_loader))['train_points'].cuda(args.gpu, non_blocking=True)
    unseen_inputs = next(iter(test_loader))['test_points'].cuda(args.gpu, non_blocking=True)
    del test_loader

    print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
    for epoch in range(start_epoch, args.epochs + 1):
        start_time = time.time()
        if args.distributed:
            # Reshuffle the distributed shards differently every epoch.
            train_sampler.set_epoch(epoch)

        if writer is not None:
            writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], epoch)

        model.train()
        # train for one epoch
        for bidx, data in enumerate(train_loader):
            step = bidx + len(train_loader) * (epoch - 1)  # epochs are 1-based
            tr_batch = data['train_points']
            if args.random_rotate:
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)

            inputs = tr_batch.cuda(args.gpu, non_blocking=True)
            B, N, D = inputs.shape
            # Per-point noise level drawn uniformly from [std_min, std_max).
            std = (args.std_max - args.std_min) * torch.rand_like(inputs[:, :, 0]).view(B, N, 1) + args.std_min

            eps = torch.randn_like(inputs) * std
            std_in = std / args.std_max * args.std_scale
            inputs_noisy = inputs + eps
            # The model performs its own optimizer step.
            out = model(inputs, inputs_noisy, std_in, optimizer, step, writer)
            entropy, prior_nats, recon_nats, loss = out['entropy'], out['prior_nats'], out['recon_nats'], out['loss']
            if step % args.log_freq == 0:
                duration = time.time() - start_time
                start_time = time.time()
                if writer is not None:
                    writer.add_scalar('train/avg_time', duration, step)
                print("[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f loss %2.5f"
                      % (args.rank, epoch, bidx, len(train_loader), duration, entropy,
                         prior_nats, recon_nats, loss))
            del inputs, inputs_noisy, std_in, out, eps
            gc.collect()

        if epoch < args.stop_scheduler:
            scheduler.step()

        if epoch % args.valid_freq == 0:
            with torch.no_grad():
                model.eval()
                valid_loss = 0.0
                valid_entropy = 0.0
                valid_prior = 0.0
                valid_prior_nats = 0.0
                valid_recon = 0.0
                valid_recon_nats = 0.0
                # NOTE(review): validation iterates train_loader and reads its
                # 'test_points' (the test_loader was deleted above). Confirm
                # this is the intended validation split.
                for bidx, data in enumerate(train_loader):
                    step = bidx + len(train_loader) * epoch
                    tr_batch = data['test_points']
                    if args.random_rotate:
                        tr_batch, _, _ = apply_random_rotation(
                            tr_batch, rot_axis=train_loader.dataset.gravity_axis)

                    inputs = tr_batch.cuda(args.gpu, non_blocking=True)
                    B, N, D = inputs.shape
                    std = (args.std_max - args.std_min) * torch.rand_like(inputs[:, :, 0]).view(B, N, 1) + args.std_min

                    eps = torch.randn_like(inputs) * std
                    std_in = std / args.std_max * args.std_scale
                    inputs_noisy = inputs + eps
                    out = model(inputs, inputs_noisy, std_in, optimizer, step, writer, valid=True)
                    # Running means over the batches of the loader.
                    valid_loss += out['loss'] / len(train_loader)
                    valid_entropy += out['entropy'] / len(train_loader)
                    valid_prior += out['prior'] / len(train_loader)
                    valid_prior_nats += out['prior_nats'] / len(train_loader)
                    valid_recon += out['recon'] / len(train_loader)
                    valid_recon_nats += out['recon_nats'] / len(train_loader)
                    del inputs, inputs_noisy, std_in, out, eps
                    gc.collect()

                if writer is not None:
                    writer.add_scalar('valid/entropy', valid_entropy, epoch)
                    writer.add_scalar('valid/prior', valid_prior, epoch)
                    writer.add_scalar('valid/prior(nats)', valid_prior_nats, epoch)
                    writer.add_scalar('valid/recon', valid_recon, epoch)
                    writer.add_scalar('valid/recon(nats)', valid_recon_nats, epoch)
                    writer.add_scalar('valid/loss', valid_loss, epoch)

                duration = time.time() - start_time
                start_time = time.time()
                print("[Valid] Epoch %d Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f loss %2.5f loss_best %2.5f"
                      % (epoch, duration, valid_entropy, valid_prior_nats, valid_recon_nats, valid_loss, valid_loss_best))
                if valid_loss < valid_loss_best:
                    valid_loss_best = valid_loss
                    if not args.distributed or (args.rank % ngpus_per_node == 0):
                        save(model, optimizer, epoch + 1, scheduler, valid_loss_best, log_dir,
                             os.path.join(save_dir, 'checkpoint-best.pt'))
                        print('best model saved!')

        if epoch % args.save_freq == 0 and (not args.distributed or (args.rank % ngpus_per_node == 0)):
            save(model, optimizer, epoch + 1, scheduler, valid_loss_best, log_dir,
                 os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
            save(model, optimizer, epoch + 1, scheduler, valid_loss_best, log_dir,
                 latest_ckpt_path)
            print('model saved!')

        # save visualizations
        if epoch % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()
                # reconstructions of unseen (test) shapes
                samples = model.reconstruct(unseen_inputs)
                results = []
                for idx in range(min(16, unseen_inputs.size(0))):
                    res = visualize_point_clouds(samples[idx], unseen_inputs[idx], idx,
                                                 pert_order=train_loader.dataset.display_axis_order)
                    results.append(res)
                res = np.concatenate(results, axis=1)
                imageio.imwrite(os.path.join(save_dir, 'images', 'SPF_epoch%d-gpu%s_recon_unseen.png' % (epoch, args.gpu)),
                                res.transpose(1, 2, 0))
                if writer is not None:
                    writer.add_image('tr_vis/conditioned', torch.as_tensor(res), epoch)

                # reconstructions of seen (training) shapes
                samples = model.reconstruct(seen_inputs)
                results = []
                for idx in range(min(16, seen_inputs.size(0))):
                    res = visualize_point_clouds(samples[idx], seen_inputs[idx], idx,
                                                 pert_order=train_loader.dataset.display_axis_order)
                    results.append(res)
                res = np.concatenate(results, axis=1)
                imageio.imwrite(os.path.join(save_dir, 'images', 'SPF_epoch%d-gpu%s_recon_seen.png' % (epoch, args.gpu)),
                                res.transpose(1, 2, 0))
                if writer is not None:
                    # Separate tag: reusing 'tr_vis/conditioned' at the same
                    # step hid the unseen-shape reconstructions above.
                    writer.add_image('tr_vis/conditioned_seen', torch.as_tensor(res), epoch)

                # unconditional samples
                num_samples = min(16, unseen_inputs.size(0))
                num_points = unseen_inputs.size(1)
                _, samples = model.sample(num_samples, num_points)
                results = []
                for idx in range(num_samples):
                    res = visualize_point_clouds(samples[idx], unseen_inputs[idx], idx,
                                                 pert_order=train_loader.dataset.display_axis_order)
                    results.append(res)
                res = np.concatenate(results, axis=1)
                imageio.imwrite(os.path.join(save_dir, 'images', 'SPF_epoch%d-gpu%s_sample.png' % (epoch, args.gpu)),
                                res.transpose((1, 2, 0)))
                if writer is not None:
                    writer.add_image('tr_vis/sampled', torch.as_tensor(res), epoch)

                print('image saved!')