Example #1
def get_scheduler(lr, epoch, sched, optimizer, loader):
    scheduler = None
    if sched == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   lr,
                                   n_iter=len(loader) * epoch,
                                   momentum=None)
    return scheduler
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]

    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = 'data/train.txt'
    images_path = '/data'
    labels_path = '/data'

    dataset = txtDataset(txt_path,
                         images_path,
                         labels_path,
                         transform=transform)

    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        # `batch_size` is assumed to be defined at module level
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #3
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transforms = video_transforms.Compose([
        RandomSelectFrames(16),
        video_transforms.Resize(args.size),
        video_transforms.CenterCrop(args.size),
        volume_transforms.ClipToTensor(),
        tensor_transforms.Normalize(0.5, 0.5)
    ])

    with open(
            '/home/shirakawa/movie/code/iVideoGAN/over16frame_list_training.txt',
            'rb') as f:
        train_file_list = pickle.load(f)
    print(len(train_file_list))

    dataset = MITDataset(train_file_list, transform=transforms)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    #loader = DataLoader(
    #    dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    #)
    loader = DataLoader(dataset,
                        batch_size=32 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint_vid_v2/vqvae_{str(i + 1).zfill(3)}.pt")
Example #4
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=128 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.load_path:
        load_state_dict = torch.load(args.load_path, map_location=device)
        model.load_state_dict(load_state_dict)
        print('successfully loaded model')

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #5
    args = parser.parse_args()

    print(args)

    device = 'cuda'

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = datasets.ImageFolder(args.path, transform=transform)
    loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    model = nn.DataParallel(VQVAE()).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   args.lr,
                                   n_iter=len(loader) * args.epoch,
                                   momentum=None)

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)
        torch.save(model.module.state_dict(),
                   f'checkpoint/vqvae_{str(i + 1).zfill(3)}.pt')
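All of these examples pass scheduler into a train() helper that is not shown. In the reference implementation, CycleScheduler is stepped once per batch, immediately before optimizer.step(). The sketch below illustrates that pattern for the single-GPU VQVAE case; the loss composition (MSE reconstruction plus a weighted latent/commitment term, latent_loss_weight=0.25) is an assumption for illustration, not taken from the snippets above.

from torch import nn


def train(epoch, loader, model, optimizer, scheduler, device):
    # Sketch of a per-batch training loop: CycleScheduler adjusts the learning
    # rate every iteration, so scheduler.step() is called once per batch.
    model.train()
    criterion = nn.MSELoss()
    latent_loss_weight = 0.25  # assumed weighting of the latent/commitment term

    for img, _ in loader:
        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        loss = recon_loss + latent_loss_weight * latent_loss.mean()

        optimizer.zero_grad()
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()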
Example #6
def main():
    # model_name = "mne-interp-mc"
    # model_name = "vq-2-mc"
    model_name = "nn"
    # model_name = "cvq-2-mc"
    # model_name = HOME_PATH + "reconstruction/saved_models/" + "vq-2-mc"

    if check_valid_filename(model_name):
        # the model name is filepath to the model
        saved_model = True
    else:
        saved_model = False

    z_dim = 30
    lr = 1e-3
    # sched = 'cycle'
    sched = None

    num_epochs = 200
    batch_size = 64
    num_examples_train = -1
    num_examples_eval = -1

    device = 'cuda'
    normalize = True
    log_interval = 1
    tb_save_loc = "runs/testing/"
    select_channels = [0]  #[0,1,2,3]
    num_channels = len(select_channels)

    lengths = {
        # Single Channel Outputs
        "nn": 784,
        "cnn": 784,
        "vq": 784,
        "vq-2": 1024,
        "unet": 1024,

        # Multichannel Outputs
        "cnn-mc": 784,
        "vq-2-mc": 1024,
        "cvq-2-mc": 1023,

        # Baselines
        "avg-interp": 784,
        "mne-interp": 784,
        "mne-interp-mc": 1023,
    }

    models = {
        # "nn": cVAE1c(z_dim=z_dim),
        "nn": VAE1c(z_dim=z_dim),
        "vq-2": cVQVAE_2(in_channel=1),
        "unet": UNet(in_channels=num_channels),
        "vq-2-mc": VQVAE_2(in_channel=num_channels),
        "cvq-2-mc": cVQVAE_2(in_channel=num_channels),
        "cnn-mc": ConvVAE(num_channels=num_channels,
                          num_channels_out=num_channels,
                          z_dim=z_dim),
        "avg-interp": AvgInterpolation(),
        "mne-interp": MNEInterpolation(),
        "mne-interp-mc": MNEInterpolationMC(),
    }

    model_filenames = {
        "nn": HOME_PATH + "models/VAE1c.py",
        "cnn": HOME_PATH + "models/conv_VAE.py",
        "vq": HOME_PATH + "models/VQ_VAE_1c.py",
        "vq-2": HOME_PATH + "models/vq-vae-2-pytorch/vqvae.py",
        "unet": HOME_PATH + "models/unet.py",
        "cnn-mc": HOME_PATH + "models/conv_VAE.py",
        "vq-2-mc": HOME_PATH + "models/vq-vae-2-pytorch/vqvae.py",
        "mne-interp-mc": HOME_PATH + "denoise/fill_baseline_models.py",
    }

    if saved_model:
        model = torch.load(model_name)
        length = 1024  # TODO find way to set auto
    else:
        model = models[model_name]
        length = lengths[model_name]

    if model_name == "mne-interp" or model_name == "mne-interp-mc":
        select_channels = [0, 1, 2, 3]
    # else:
    # select_channels = [0,1,2]
    # select_channels = [0,1]#,2,3]

    train_files = TRAIN_NORMAL_FILES_CSV  #TRAIN_FILES_CSV
    # train_files =  TRAIN_FILES_CSV

    eval_files = DEV_NORMAL_FILES_CSV  #DEV_FILES_CSV
    # eval_files =  DEV_FILES_CSV
    eval_dataset = EEGDatasetMc(eval_files,
                                max_num_examples=num_examples_eval,
                                length=length,
                                normalize=normalize,
                                select_channels=select_channels)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

    target_filename = model_name
    run_filename = find_valid_filename(target_filename,
                                       HOME_PATH + 'denoise/' + tb_save_loc)
    tb_filename = tb_save_loc + run_filename
    writer = SummaryWriter(tb_save_loc + run_filename)

    model = model.to(device)

    try:
        optimizer = optim.Adam(model.parameters(), lr=lr)
        train_model = True
    except ValueError:
        # Parameter-free baselines (e.g. the interpolation models) have no
        # trainable parameters, so optim.Adam raises ValueError here.
        print("This Model Cannot Be Optimized")
        train_model = False
        sched = None

    if train_model:
        train_dataset = EEGDatasetMc(train_files,
                                     max_num_examples=num_examples_train,
                                     length=length,
                                     normalize=normalize,
                                     select_channels=select_channels)
        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True)
        print("m Train Dataset", len(train_dataset))

    print("m Eval Dataset", len(eval_dataset))

    if saved_model:
        train_model = True

    scheduler = None

    if sched == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   lr,
                                   n_iter=len(train_loader) * num_epochs,
                                   momentum=None)
    for i in range(1, num_epochs + 1):
        if train_model:
            train(i,
                  train_loader,
                  model,
                  optimizer,
                  scheduler,
                  device,
                  writer,
                  log_interval=log_interval)
        eval(i, eval_loader, model, device, writer, log_interval=log_interval)

    save_dir = HOME_PATH + "denoise/saved_runs/" + str(int(time.time())) + "/"
    recon_file = HOME_PATH + "denoise/fill_1c.py"
    train_file = HOME_PATH + "denoise/train_fill_1c.py"
    model_filename = model_filenames[model_name]
    python_files = [recon_file, train_file, model_filename]

    info_dict = {
        "model_name": model_name,
        "z_dim": z_dim,
        "lr": lr,
        "sched": sched,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "num_examples_train": num_examples_train,
        "num_examples_eval": num_examples_eval,
        "train_files": train_files,
        "eval_files": eval_files,
        "device": device,
        "normalize": normalize.__name__,
        "log_interval": log_interval,
        "tb_dirpath": tb_filename
    }

    save_run(save_dir, python_files, model, info_dict)
    for key, value in info_dict.items():
        print(key + ":", value)
Example #7
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]

    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = './data/train.txt'
    images_path = './data'
    labels_path = './data'

    dataset = txtDataset(txt_path,
                         images_path,
                         labels_path,
                         transform=transform)

    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        # `batch_size` is assumed to be defined at module level
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    # Initialize generator and discriminator
    DpretrainedPath = './checkpoint/vqvae2GAN_040.pt'
    GpretrainedPath = './checkpoint/vqvae_040.pt'

    discriminator = Discriminator()
    generator = Generator()
    if os.path.exists(DpretrainedPath):
        print('Loading model weights...')
        discriminator.load_state_dict(
            torch.load(DpretrainedPath)['discriminator'])
        print('done')
    if os.path.exists(GpretrainedPath):
        print('Loading model weights...')
        generator.load_state_dict(torch.load(GpretrainedPath))
        print('done')

    discriminator = discriminator.to(device)
    generator = generator.to(device)

    if args.distributed:
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr)
    optimizer_G = optim.Adam(generator.parameters(), lr=args.lr)
    scheduler_D = None
    scheduler_G = None
    if args.sched == "cycle":
        scheduler_D = CycleScheduler(
            optimizer_D,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

        scheduler_G = CycleScheduler(
            optimizer_G,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(41, args.epoch):
        train(i, loader, discriminator, generator, scheduler_D, scheduler_G,
              optimizer_D, optimizer_G, device)

        if dist.is_primary():
            torch.save(
                {
                    'generator': generator.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'g_optimizer': optimizer_G.state_dict(),
                    'd_optimizer': optimizer_D.state_dict(),
                },
                f'checkpoint/vqvae2GAN_{str(i + 1).zfill(3)}.pt',
            )
            if (i + 1) % n_critic == 0:  # n_critic: assumed module-level constant
                torch.save(generator.state_dict(),
                           f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #8
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = OffsetDataset(args.path, transform=transform, offset=args.offset)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=args.bsize // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    # Load pre-trained VQVAE
    vqvae = VQVAE().to(device)
    try:
        vqvae.load_state_dict(torch.load(args.ckpt))
    except RuntimeError:
        print(
            "Checkpoint appears to have been saved from a DataParallel model; "
            "stripping the 'module.' prefix and retrying"
        )
        weights = torch.load(args.ckpt)
        weights = {
            key.replace('module.', ''): value
            for key, value in weights.items()
        }
        vqvae.load_state_dict(weights)

    # Init offset encoder
    model = OffsetNetwork(vqvae).to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/offset_enc_{str(i + 1).zfill(3)}.pt")
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', help='Image size', type=int, default=256)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--bs', type=int, default=64)
    parser.add_argument('--sched', type=str, default='cycle')
    parser.add_argument('--vishost', type=str, default='localhost')
    parser.add_argument('--visport', type=int, default=8097)
    parser.add_argument('path',
                        help="root path with train and test folder in it",
                        type=str)

    args = parser.parse_args()

    print(args)

    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ])

    train_path = os.path.join(args.path, "train")
    test_path = os.path.join(args.path, "test")

    train_dataset = datasets.ImageFolder(train_path, transform=transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.bs,
                              shuffle=True,
                              num_workers=4)

    test_dataset = datasets.ImageFolder(test_path, transform=transform)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.bs,
                             shuffle=False,
                             num_workers=4)

    model = VQVAE().to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    if args.sched == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   args.lr,
                                   n_iter=len(train_loader) * args.epoch,
                                   momentum=None)
    else:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [50, 70], 0.1)

    train_losses = []
    test_losses = []
    vis = visdom.Visdom(server=args.vishost, port=args.visport)
    win = None
    best_model_loss = np.inf
    for i in range(args.epoch):
        # Training stage
        print(f"Training epoch {i + 1}")
        train_loss = train(i, train_loader, model, optimizer, scheduler,
                           device)
        print(f"Train Loss: {train_loss:.5f}")

        # Testing stage
        print(f"Testing epoch {i + 1}")
        test_loss, test_recon_error, test_commitment_loss = test(
            i, test_loader, model, device)
        print(f"Test Loss: {test_loss:.5f}")
        torch.save(model.state_dict(),
                   f'checkpoints/vqvae_chkpt_{str(i + 1).zfill(3)}.pt')

        if test_loss < best_model_loss:
            print("Saving model")
            torch.save(model.state_dict(), 'weights/vqvae.pt')
            best_model_loss = test_loss

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        win = plot(train_losses, test_losses, vis, win)

        # Sampling stage
        recon_sample(i, model, test_loader, device)
Example #10
        )

    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr / 25,
        momentum=0.9,
        weight_decay=args.l2,
        nesterov=True,
    )

    max_iter = math.ceil(len(train_set) / (n_gpu * args.batch)) * args.epoch

    scheduler = CycleScheduler(
        optimizer,
        args.lr,
        n_iter=max_iter,
        warmup_proportion=0.01,
        phase=('linear', 'poly'),
    )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch,
        num_workers=2,
        sampler=data_sampler(train_set, shuffle=True, distributed=args.distributed),
    )
    valid_loader = DataLoader(
        valid_set,
        batch_size=args.batch,
        num_workers=2,
        sampler=data_sampler(valid_set, shuffle=False, distributed=args.distributed),
Example #11
def main():
	model_name = "cnn"
	z_dim = 30
	lr = 1e-3
	sched = None

	num_epochs = 1000
	batch_size = 64
	num_examples_train = -1
	num_examples_eval = -1
	
	device = 'cuda'
	normalize = True
	
	log_interval = 2
	tb_save_loc = "runs/testing/"

	lengths = {
		"nn" : 784,
		"cnn" : 784,
		"vq" : 784,
		"vq-2" : 1024,
		"unet" : 1024,
	}

	models = {
		"nn" : VAE1c(z_dim=z_dim),
		"cnn" : ConvVAE(num_channels=1, z_dim=z_dim),
		"vq" : VQVAE(hidden=z_dim),
		"vq-2" : VQVAE_2(in_channel=1),
		"unet" : UNet()
	}


	length = lengths[model_name]
	model = models[model_name]

	train_files =  TRAIN_NORMAL_FILES_CSV #TRAIN_FILES_CSV 
	train_dataset = EEGDataset1c(train_files, max_num_examples=num_examples_train, length=length, normalize=normalize)
	train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

	eval_files =  DEV_NORMAL_FILES_CSV #DEV_FILES_CSV 
	eval_dataset = EEGDataset1c(eval_files, max_num_examples=num_examples_eval, length=length, normalize=normalize)
	eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
	

	print("m Train Dataset", len(train_dataset))
	print("m Eval Dataset", len(eval_dataset))

	target_filename = model_name
	run_filename = find_valid_filename(target_filename, HOME_PATH + 'reconstruction/' + tb_save_loc)
	writer = SummaryWriter(tb_save_loc + run_filename)

	model = model.to(device)

	optimizer = optim.Adam(model.parameters(), lr=lr)
	scheduler = None

	if sched == 'cycle':
		scheduler = CycleScheduler(
			optimizer, lr, n_iter=len(train_loader) * num_epochs, momentum=None
		)
	for i in range(1, num_epochs + 1):
		train(i, train_loader, model, optimizer, scheduler, device, writer, log_interval=log_interval)
		eval(i, eval_loader, model, device, writer, log_interval=log_interval)
Example #12
def create_training_run(size, num_epochs, lr, sched, dataset, architecture,
                        data_path, device, num_embeddings, neighborhood,
                        selection_fn, num_workers, vae_batch, eval_iter,
                        embed_dim, parallelize, download, **kwargs):
    experiment_name = create_experiment_name(architecture, dataset,
                                             num_embeddings, neighborhood,
                                             selection_fn=selection_fn,
                                             size=size, lr=lr, **kwargs)
    _today = getattr(args, 'today_str', _TODAY)  # args and _TODAY are module-level globals
    log_arguments(**vars(args))
    writer = SummaryWriter(os.path.join(args.summary_path, _today, experiment_name))

    print('Loading datasets')
    train_dataset, test_dataset = get_dataset(dataset, data_path, size, download)

    print('Creating loaders')
    train_loader = DataLoader(train_dataset, batch_size=vae_batch, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=vae_batch, shuffle=True, num_workers=num_workers)

    print('Initializing models')
    model = get_model(architecture, num_embeddings, device, neighborhood, selection_fn, embed_dim, parallelize, **kwargs)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = None
    if sched == 'cycle':
        scheduler = CycleScheduler(
            optimizer, lr, n_iter=len(train_loader) * num_epochs, momentum=None
        )

    train_mses = []
    train_losses = []
    test_mses = []
    test_losses = []

    for epoch_ind in range(1, num_epochs+1):
        (avg_mse, avg_loss, avg_quantization_steps, avg_Z, avg_D, avg_norm_Z,
         avg_top_percentile, avg_num_zeros) = do_epoch(
             epoch_ind, train_loader, model, writer, experiment_name, device,
             optimizer, scheduler,
             dictionary_loss_weight=kwargs['dictionary_loss_weight'],
             sampling_iter=kwargs['sampling_iter'],
             sample_size=kwargs['sample_size'])
        train_mses.append(avg_mse)
        train_losses.append(avg_loss)
        writer.add_scalar('train/avg_mse', avg_mse, epoch_ind)
        writer.add_scalar('train/avg_loss', avg_loss, epoch_ind)
        writer.add_scalar('train/avg_quantization_steps', avg_quantization_steps, epoch_ind)
        writer.add_scalar('train/avg_Z', avg_Z, epoch_ind)
        writer.add_scalar('train/avg_D', avg_D, epoch_ind)
        writer.add_scalar('train/avg_norm_Z', avg_norm_Z, epoch_ind)
        writer.add_scalar('train/avg_top_percentile', avg_top_percentile, epoch_ind)
        writer.add_scalar('train/avg_num_zeros', avg_num_zeros, epoch_ind)

        if epoch_ind % kwargs['checkpoint_freq'] == 0:

            cp_path = os.path.join(args.checkpoint_path, _today, create_checkpoint_name(experiment_name, epoch_ind))
            os.makedirs(osp.dirname(cp_path), exist_ok=True)
            if parallelize:  # If using DataParallel we need to access the inner module
                torch.save(model.module.state_dict(), cp_path)
            else:
                torch.save(model.state_dict(), cp_path)

        if epoch_ind % eval_iter == 0:
            (avg_mse, avg_loss, avg_quantization_steps, avg_Z, avg_D,
             avg_norm_Z, avg_top_percentile, avg_num_zeros) = do_epoch(
                 epoch_ind, test_loader, model, writer, experiment_name,
                 device, optimizer, scheduler, phase='test')

            test_mses.append(avg_mse)
            test_losses.append(avg_loss)
            writer.add_scalar('test/avg_loss', avg_loss, epoch_ind)
            writer.add_scalar('test/avg_mse', avg_mse, epoch_ind)
            writer.add_scalar('test/avg_quantization_steps', avg_quantization_steps, epoch_ind)
            writer.add_scalar('test/avg_Z', avg_Z, epoch_ind)
            writer.add_scalar('test/avg_D', avg_D, epoch_ind)
            writer.add_scalar('test/avg_norm_Z', avg_norm_Z, epoch_ind)
            writer.add_scalar('test/avg_top_percentile', avg_top_percentile, epoch_ind)
            writer.add_scalar('test/avg_num_zeros', avg_num_zeros, epoch_ind)
            model.train()

    return train_mses, train_losses, test_mses, test_losses
Example #13
def prepare_model_parts(train_loader):
    global args, scheduler

    # Load specific checkpoint to continue training
    ckpt = {}
    if args.pixelsnail_ckpt is not None:
        ckpt = torch.load(args.pixelsnail_ckpt)
        args = ckpt['args']

    # Create PixelSnail object
    if args.hier == 'top':
        model = FistaPixelSNAIL(
            [args.size // 8, args.size // 8],
            512,
            args.pixelsnail_channel,
            5,
            4,
            args.pixelsnail_n_res_block,
            args.pixelsnail_n_res_channel,
            dropout=args.pixelsnail_dropout,
            n_out_res_block=args.pixelsnail_n_out_res_block,
        )

    elif args.hier == 'bottom':
        model = FistaPixelSNAIL(
            [args.size // 4, args.size // 4],
            512,
            args.pixelsnail_channel,
            5,
            4,
            args.pixelsnail_n_res_block,
            args.pixelsnail_n_res_channel,
            attention=False,
            dropout=args.pixelsnail_dropout,
            n_cond_res_block=args.pixelsnail_n_cond_res_block,
            cond_res_channel=args.pixelsnail_n_res_channel,
        )
    else:
        raise ValueError(f"Unsupported hier value: {args.hier}")

    # Load saved checkpoint into PixelSnail object
    if 'model' in ckpt:
        model.load_state_dict(ckpt['model'])

    # Parallelize training
    model = nn.DataParallel(model)

    # Move model to proper device
    model = model.to(args.device)

    # Create other training objects
    optimizer = optim.Adam(model.parameters(), lr=args.pixelsnail_lr)
    if amp is not None:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp)

    scheduler = None
    if args.pixelsnail_sched == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   args.pixelsnail_lr,
                                   n_iter=len(train_loader) *
                                   args.pixelsnail_epoch,
                                   momentum=None)
    return model, optimizer
Example #14
    def run(self, args):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        transform = [transforms.ToTensor()]

        if args.normalize:
            transform.append(
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))

        transform = transforms.Compose(transform)

        dataset = datasets.ImageFolder(args.path, transform=transform)
        sampler = dist_fn.data_sampler(dataset,
                                       shuffle=True,
                                       distributed=args.distributed)
        loader = DataLoader(dataset,
                            batch_size=args.batch_size // args.n_gpu,
                            sampler=sampler,
                            num_workers=args.num_workers)

        self = self.to(device)

        if args.distributed:
            self = nn.parallel.DistributedDataParallel(
                self,
                device_ids=[dist_fn.get_local_rank()],
                output_device=dist_fn.get_local_rank())

        optimizer = args.optimizer(self.parameters(), lr=args.lr)
        scheduler = None
        if args.sched == 'cycle':
            scheduler = CycleScheduler(
                optimizer,
                args.lr,
                n_iter=len(loader) * args.epoch,
                momentum=None,
                warmup_proportion=0.05,
            )

        start = str(time())
        run_path = os.path.join('runs', start)
        sample_path = os.path.join(run_path, 'sample')
        checkpoint_path = os.path.join(run_path, 'checkpoint')
        os.makedirs(run_path, exist_ok=True)  # also creates the 'runs' parent if missing
        os.mkdir(sample_path)
        os.mkdir(checkpoint_path)

        with Progress() as progress:
            train = progress.add_task(f'epoch 1/{args.epoch}',
                                      total=args.epoch,
                                      columns='epochs')
            steps = progress.add_task('',
                                      total=len(dataset) // args.batch_size)

            for epoch in range(args.epoch):
                progress.update(steps, completed=0, refresh=True)

                for recon_loss, latent_loss, avg_mse, lr in self.train_epoch(
                        epoch, loader, optimizer, scheduler, device,
                        sample_path):
                    progress.update(
                        steps,
                        description=
                        f'mse: {recon_loss:.5f}; latent: {latent_loss:.5f}; avg mse: {avg_mse:.5f}; lr: {lr:.5f}'
                    )
                    progress.advance(steps)

                if dist_fn.is_primary():
                    torch.save(
                        self.state_dict(),
                        os.path.join(checkpoint_path,
                                     f'vqvae_{str(epoch + 1).zfill(3)}.pt'))

                progress.update(train,
                                description=f'epoch {epoch + 1}/{args.epoch}')
                progress.advance(train)
Example #15
            n_cond_res_block=conf['pixelsnail']['n_cond_res_block'],
            cond_res_channel=conf['pixelsnail']['n_res_channel'])
        if conf['pixelsnail']['load']:
            weights = torch.load(
                path_weights / 'pixelsnail_1' /
                conf['pixelsnail']['name_bottom'])['state_dict']
            model.load_state_dict(weights)

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=conf['pixelsnail']['lr'])

    scheduler = None
    if conf['pixelsnail']['sched'] == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   conf['pixelsnail']['lr'],
                                   n_iter=len(train_loader) *
                                   conf['pixelsnail']['epochs'],
                                   momentum=None)

    writer = SummaryWriter(
        log_dir=f"../tensorboard/{conf['experiment']}_{hier}/")

    for i in range(conf['pixelsnail']['epochs']):
        train_loss = train(hier, i, train_loader, model, optimizer, scheduler,
                           device)
        test_loss = test(hier, i, test_loader, model, device)
        writer.add_scalar('PixelSNAIL/Train Loss', train_loss, i)
        writer.add_scalar('PixelSNAIL/Test Loss', test_loss, i)

        saveModel(path_weights, hier, model, optimizer, i)