예제 #1
0
def train_vae():
    """Train a VAE on the ``'data'`` dataset with simple early stopping.

    All hyper-parameters are fixed below. After every epoch the mean
    training loss and mean validation loss are printed and a checkpoint
    (epoch, model/optimizer state dicts, mean training loss) is written to
    ``vae/upsample_checkpoint_<epoch>.pth``. Training stops once the
    validation loss has gone up for ``patience`` consecutive epochs.
    """
    batch_size = 64
    epochs = 1000
    latent_dimension = 100
    patience = 10

    device = torch.device(
        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    # load data (the third returned loader is not used here)
    train_loader, valid_loader, _ = get_data_loader('data', batch_size)

    model = VAE(latent_dimension).to(device)

    optim = Adam(model.parameters(), lr=1e-3)

    val_greater_count = 0
    # Start at +inf so the first epoch can never be counted as a validation
    # increase (starting from 0 made every first epoch count against
    # ``patience``).
    last_val_loss = float('inf')
    for e in range(epochs):
        running_loss = 0
        model.train()
        for images, _ in train_loader:
            images = images.to(device)
            model.zero_grad()
            outputs, mu, logvar = model(images)
            loss = compute_loss(images, outputs, mu, logvar)
            # Accumulate a detached copy: summing ``loss`` itself chains every
            # batch's autograd graph onto ``running_loss`` and keeps it alive
            # for the whole epoch.
            running_loss += loss.detach()
            loss.backward()
            optim.step()

        running_loss = running_loss / len(train_loader)
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for images, _ in valid_loader:
                images = images.to(device)
                outputs, mu, logvar = model(images)
                val_loss += compute_loss(images, outputs, mu, logvar)
            val_loss /= len(valid_loader)

        # Early stopping: count consecutive epochs whose validation loss is
        # higher than the previous epoch's; any non-increase resets the count.
        if val_loss > last_val_loss:
            val_greater_count += 1
        else:
            val_greater_count = 0
        last_val_loss = val_loss

        torch.save(
            {
                'epoch': e,
                'model': model.state_dict(),
                'running_loss': running_loss,
                'optim': optim.state_dict(),
            }, "vae/upsample_checkpoint_{}.pth".format(e))
        print("Epoch: {} Train Loss: {}".format(e + 1, running_loss.item()))
        print("Epoch: {} Val Loss: {}".format(e + 1, val_loss.item()))
        if val_greater_count >= patience:
            break
예제 #2
0
def _time_batches(data_loader, max_batches, log_steps=False):
    """Pull at most ``max_batches`` batches from ``data_loader`` and time it.

    Returns ``(elapsed_seconds, num_batches)``. The loop stops right after
    the last wanted batch, so no extra batch is fetched (and timed) just to
    trigger the break. When ``log_steps`` is true every step is logged.
    """
    if max_batches <= 0:
        return 0.0, 0
    start_time = time.time()
    num_batches = 0
    for step, _ in enumerate(data_loader):
        if log_steps:
            logging.info("Processed example: {}".format(step))
        num_batches += 1
        if num_batches == max_batches:
            break
    return time.time() - start_time, num_batches


def _log_speed(log_str, run, bs, duration, num_batches):
    """Log the throughput of benchmark ``run`` from its total ``duration``."""
    if num_batches == 0:
        logging.warning("Bench run %d loaded no batches; nothing to report.",
                        run)
        return
    sec_per_batch = duration / num_batches
    # Guard against a zero-length timing on very coarse clocks.
    imgs_per_sec = bs / sec_per_batch if sec_per_batch > 0 else float('inf')
    logging.info(log_str.format("Bench", run, 2, imgs_per_sec, sec_per_batch))


def speed_bench(num_iters=30, bs=1):
    """Benchmark data-loader throughput over two consecutive passes.

    Each pass loads up to ``num_iters`` batches of size ``bs`` and logs the
    elapsed time plus images/sec and sec/batch. The second pass re-iterates
    the same loader object.

    Args:
        num_iters: maximum number of batches timed per pass.
        bs: batch size passed to the loader and used for imgs/sec.
    """
    log_str = ("    {:8} [{:3d}/{:3d}] "
               " Speed: {:.1f} imgs/sec ({:.3f} sec/batch)")

    conf = loader.default_conf.copy()
    conf['num_worker'] = 8

    # Use ``bs`` here so the reported imgs/sec matches the real batch size.
    myloader = loader.get_data_loader(
        conf=conf, batch_size=bs, pin_memory=False)

    duration, count = _time_batches(myloader, num_iters, log_steps=True)
    logging.info("Loading {} examples took: {}".format(count, duration))
    _log_speed(log_str, 1, bs, duration, count)

    duration, count = _time_batches(myloader, num_iters)
    logging.info("Loading another {} examples took: {}".format(
        count, duration))
    _log_speed(log_str, 2, bs, duration, count)
예제 #3
0
def test_scatter_plot_2d():
    """Smoke-test ``LocalSegVisualizer.scatter_plot`` on a 2-D spatial label.

    Takes the first label of one validation batch and scatters it against
    a noisy copy of itself (uniform noise in [-0.5, 0.5)).
    """
    overrides = {
        'label_encoding': 'spatial_2d',
        'grid_dims': 2,
        'grid_size': 10,
    }
    conf = loader.default_conf.copy()
    for key, value in overrides.items():
        conf[key] = value

    val_loader = loader.get_data_loader(conf=conf,
                                        batch_size=6,
                                        pin_memory=False,
                                        split='val')

    first_batch = next(iter(val_loader))
    vis = visualizer.LocalSegVisualizer(class_file=class_file, conf=conf)

    target = first_batch['label'][0].numpy()
    noise = np.random.random(target.shape) - 0.5

    vis.scatter_plot(label=target, prediction=noise + target)
예제 #4
0
def test_plot_batch_2d(verbose=False):
    """Smoke-test loading a spatial_2d validation batch and its visualizer.

    By default this only loads one batch and builds the visualizer; the
    (slow) plotting step is skipped, as in the original, which returned
    before it. Pass ``verbose=True`` to also plot the batch and log how
    long that took.
    """
    conf = loader.default_conf.copy()
    conf['label_encoding'] = 'spatial_2d'
    conf['grid_dims'] = 2
    conf['grid_size'] = 10
    myloader = loader.get_data_loader(conf=conf,
                                      batch_size=6,
                                      pin_memory=False,
                                      split='val')
    batch = next(iter(myloader))

    myvis = visualizer.LocalSegVisualizer(class_file=class_file, conf=conf)

    # Explicit opt-in instead of a bare ``return`` followed by dead code.
    if not verbose:
        return

    start_time = time.time()
    myvis.plot_batch(batch)
    duration = time.time() - start_time

    logging.info("Visualizing one batch took {} seconds".format(duration))
예제 #5
0
def test_plot_batch(verbose=False):
    """Smoke-test plotting one ``blender_mini`` training batch.

    NOTE(review): this test is currently disabled by the unconditional
    ``return`` below; everything after it is unreachable and is kept only
    so the test can be re-enabled.
    """
    conf = loader.default_conf.copy()
    conf['dataset'] = 'blender_mini'

    return

    myloader = loader.get_data_loader(conf=conf,
                                      batch_size=6,
                                      pin_memory=False,
                                      split='train')
    batch = next(iter(myloader))

    myvis = visualizer.LocalSegVisualizer(class_file=class_file, conf=conf)
    if verbose:
        start_time = time.time()
        myvis.plot_batch(batch)
        duration = time.time() - start_time
        # Log inside the branch: ``duration`` only exists when we plotted
        # (logging it unconditionally raised NameError for verbose=False).
        logging.info("Visualizing one batch took {} seconds".format(duration))
예제 #6
0
def test_loading():
    """Smoke-test the default data loader by pulling its first ten batches.

    Logs every processed step and the total wall-clock time.
    """
    settings = loader.default_conf.copy()
    settings['num_worker'] = 8

    data_iter = loader.get_data_loader(
        conf=settings, batch_size=1, pin_memory=False)

    tic = time.time()
    for idx, _ in enumerate(data_iter):
        if idx >= 10:
            break
        logging.info("Processed example: {}".format(idx))
    elapsed = time.time() - tic

    logging.info("Loading 10 examples took: {}".format(elapsed))
예제 #7
0
def test_loading_2d():
    """Smoke-test the loader with ``spatial_2d`` label encoding.

    Loads ``num_examples`` batches, logging each step and the total time.
    The count lives in one place so the final log message matches what was
    actually loaded (it previously claimed 10 while loading 2).
    """
    num_examples = 2

    conf = loader.default_conf.copy()
    conf['num_worker'] = 8
    conf['label_encoding'] = 'spatial_2d'
    conf['grid_dims'] = 2
    conf['grid_size'] = 10

    myloader = loader.get_data_loader(
        conf=conf, batch_size=1, pin_memory=False)

    start_time = time.time()

    for step, _ in enumerate(myloader):

        if step == num_examples:
            break

        logging.info("Processed example: {}".format(step))

    duration = time.time() - start_time

    logging.info("Loading {} examples took: {}".format(num_examples, duration))
예제 #8
0
def test_loading_blender(verbose=False):
    """Plot the first batches of the ``blender_mini`` dataset.

    Every batch is plotted with a freshly constructed ``LocalSegVisualizer``
    (class file taken from ``conf["vis_file"]``). Iteration stops right after
    the batch at step 5 has been plotted. When ``verbose`` is true,
    ``plt.show()`` is called after each plotted batch before step 5.
    """
    conf = loader.default_conf.copy()
    conf["dataset"] = "blender_mini"
    conf['num_worker'] = 8

    batches = loader.get_data_loader(
        conf=conf, batch_size=8, pin_memory=False)

    last_step = 5
    for step, sample in enumerate(batches):
        vis = visualizer.LocalSegVisualizer(
            class_file=conf["vis_file"], conf=conf)
        tic = time.time()
        vis.plot_batch(sample)
        _elapsed = time.time() - tic  # NOQA - measured but not reported

        if step >= last_step:
            break

        if verbose:
            plt.show()
예제 #9
0
        default=False,
        action='store_true',
        help='whether to display debug information (default False)')
    args = parser.parse_args()
    print(args)

    # select device (cuda or cpu)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # set seeds for reproducibility
    torch.manual_seed(1)
    np.random.seed(1)

    # data loaders and class count
    trainloader, validloader, testloader, args.n_classes = get_data_loader(
        args.dataset, args.batchsize)
    try:
        image_size = trainloader.batch_sampler.sampler.datasource[0][0].shape
    except AttributeError:
        image_size = trainloader.batch_sampler.sampler.data_source[0][0].shape

    # prefix for saved model and directory names
    model_prefix = args.dataset + '_'

    # classifier to extract features (for fid score computation)
    try:
        classifier = torch.load('classifier.pt', map_location='cpu')
        classifier.eval()
        print('Classifier loaded!')
    except FileNotFoundError:
        classifier = Classifier()
예제 #10
0
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("There are", num_params, "parameters in this model")

    print("Use %s transformer for training" % args.transform)
    if args.transform == "basic":
        train_transform = valid_transform = answer.basic_transformer
    elif args.transform == "norm":
        train_transform = answer.norm_transformer
        valid_transform = answer.norm_transformer
    elif args.transform == "aug":
        train_transform = answer.aug_transformer
        valid_transform = answer.norm_transformer

    trainloader, validloader = loader.get_data_loader(train_transform,
                                                      valid_transform,
                                                      args.batch_size)

    use_cuda = torch.cuda.is_available() and args.use_cuda

    if use_cuda:
        model = model.cuda()

    # %%
    train_losses = []
    valid_losses = []
    train_accs = []
    valid_accs = []
    for epoch in range(args.epoch):  # loop over the dataset multiple times
        learning_rate = 0.01 * 0.8**epoch
        learning_rate = max(learning_rate, 1e-6)
예제 #11
0
def train_gan():
    """Train a generator/discriminator pair with a gradient-penalty critic loss.

    The discriminator is updated every ``disc_update`` batches with
    ``mean(D(fake)) - mean(D(real)) + lambduh * gradient_penalty(...)``.
    The generator is updated every ``gen_update`` batches with
    ``-mean(D(G(z)))``. After each epoch a checkpoint is saved to
    ``upsample/checkpoint_<epoch>.pth`` and the mean discriminator and
    generator losses over the updates actually made that epoch are printed.
    """
    batch_size = 64
    epochs = 100
    disc_update = 1
    gen_update = 5
    latent_dimension = 100
    lambduh = 10  # gradient-penalty weight

    device = torch.device(
        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    # load data (only the training loader is used here)
    train_loader, _, _ = get_data_loader('data', batch_size)

    disc_model = Discriminator().to(device)
    gen_model = Generator(latent_dimension).to(device)

    disc_optim = Adam(disc_model.parameters(), lr=1e-4, betas=(0.5, 0.9))
    gen_optim = Adam(gen_model.parameters(), lr=1e-4, betas=(0.5, 0.9))

    for e in range(epochs):
        # Detached running sums (always tensors, so ``.item()`` works even if
        # no update happened) plus update counts for correct averaging.
        disc_loss = torch.zeros((), device=device)
        gen_loss = torch.zeros((), device=device)
        disc_steps = 0
        gen_steps = 0
        for i, (images, _) in enumerate(train_loader):
            images = images.to(device)
            b_size = images.shape[0]
            step = i + 1
            if step % disc_update == 0:
                disc_model.zero_grad()
                # sample noise directly on the target device
                noise = torch.randn((b_size, latent_dimension), device=device)

                # loss on fake (generator output detached: critic step only)
                inputs = gen_model(noise).detach()
                fake_term = disc_model(inputs).mean()

                # loss on real
                real_term = disc_model(images).mean()

                # add gradient penalty
                penalty = gradient_penalty(disc_model, images, inputs, device)
                loss = fake_term - real_term + lambduh * penalty

                # Detach so the running sum does not keep every graph alive.
                disc_loss += loss.detach()
                disc_steps += 1
                loss.backward()
                disc_optim.step()

            if step % gen_update == 0:
                gen_model.zero_grad()

                noise = torch.randn((b_size, latent_dimension), device=device)
                outputs = disc_model(gen_model(noise))
                loss = -outputs.mean()

                gen_loss += loss.detach()
                gen_steps += 1
                loss.backward()
                gen_optim.step()

        torch.save(
            {
                'epoch': e,
                'disc_model': disc_model.state_dict(),
                'gen_model': gen_model.state_dict(),
                'disc_loss': disc_loss,
                'gen_loss': gen_loss,
                'disc_optim': disc_optim.state_dict(),
                'gen_optim': gen_optim.state_dict()
            }, "upsample/checkpoint_{}.pth".format(e))
        # Average over the updates actually performed: the generator only
        # steps every ``gen_update`` batches, so dividing by the number of
        # batches understated its loss.
        print("Epoch: {} Disc loss: {}".format(
            e + 1,
            disc_loss.item() / max(disc_steps, 1)))
        print("Epoch: {} Gen loss: {}".format(
            e + 1,
            gen_loss.item() / max(gen_steps, 1)))