Example #1
0
    # Setting up logger
    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG,
                        format='%(asctime)s %(message)s')
    LOGGER = logging.getLogger(__name__)

    args = args_parse()

    # Setting up tensorboard writer; event files go under <root_dir>/runs
    writer = SummaryWriter(log_dir=os.path.join(args.root_dir, 'runs'))

    # Seed before any random sampling below so file/image selection repeats
    set_seed(args.seed)

    # Create the training dataset from a random subset of (up to) 10000 image
    # files, plus a fixed 8-image sample subset used to test the model and
    # save example images.
    dataset_dir = os.path.join(args.root_dir, 'data/celebaHQ512')
    # os.listdir returns entries in arbitrary, filesystem-dependent order;
    # sort first so the seeded sample() picks the same files on every machine
    # (assumes set_seed seeds the RNG behind `sample` — confirm).
    all_files = sorted(os.listdir(dataset_dir))
    # Cap the sample size so sample() does not raise ValueError when the
    # directory holds fewer than 10000 files.
    dataset_files = sample(all_files, min(10000, len(all_files)))
    train_dataset = CelebaDataset(dataset_dir, dataset_files, WT=False)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=10, shuffle=True)
    sample_dataset = Subset(train_dataset,
                            sample(range(len(train_dataset)), 8))
    sample_loader = DataLoader(sample_dataset, batch_size=8, shuffle=False)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Setting up WT & IWT filters on the selected device
    filters = create_filters(device=device)

    # Create the model and move it to the selected device.
    # NOTE(review): the original comment says filters are set on the model for
    # the WT loss, but this fragment never attaches `filters` to `wt_model` —
    # confirm against the full script. Also note `num_wt` is fed from
    # `args.num_iwt`; verify that is intended.
    wt_model = WTVAE_512_2(z_dim=args.z_dim, num_wt=args.num_iwt)
    wt_model = wt_model.to(device)
Example #2
0
    # Accelerate training since input sizes are fixed: cuDNN autotunes
    # convolution algorithms. Per PyTorch's reproducibility notes this may
    # pick nondeterministic algorithms, so runs are not bit-identical.
    torch.backends.cudnn.benchmark = True

    logging.basicConfig(stream=sys.stdout,
                        level=logging.DEBUG,
                        format='%(asctime)s %(message)s')
    LOGGER = logging.getLogger(__name__)

    args = args_parse()

    # Set seed before any random sampling below
    set_seed(args.seed)

    # Training dataset over every file in <root_dir>/celeba64, plus a fixed
    # 8-image sample subset used to inspect model outputs.
    dataset_dir = os.path.join(args.root_dir, 'celeba64')
    # Sort the listing: os.listdir order is arbitrary and filesystem-dependent,
    # so without it the seeded index sample below would select different
    # images on different machines.
    train_dataset = CelebaDataset(dataset_dir,
                                  sorted(os.listdir(dataset_dir)),
                                  WT=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=10,
                              shuffle=True)
    sample_dataset = Subset(train_dataset,
                            sample(range(len(train_dataset)), 8))
    sample_loader = DataLoader(sample_dataset, batch_size=8, shuffle=False)

    # Two device slots are used by the rest of the script (presumably one per
    # sub-model — confirm). NOTE(review): a host with exactly one GPU falls
    # back to CPU for both slots.
    if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
        devices = ['cuda:0', 'cuda:1']
    else:
        devices = ['cpu', 'cpu']

    if args.mask: