Code example #1

# Imports assumed by every snippet on this page; util is the project's own
# helper module and utils is torchvision.utils.
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import utils
from tqdm import tqdm
import util
def val(epoch, loader, model, optimizer, scheduler, device, img_size):

    model.eval()
    loader = tqdm(loader)
    criterion = nn.MSELoss()
    criterion.to(device)

    loss_sum = 0
    loss_n = 0

    with torch.no_grad():
        for ep_idx, (episode, _) in enumerate(loader):

            for i in range(episode.shape[1]):

                # Get, resize and reconstruct current batch of images
                img = episode[:, i]
                img = util.resize_img(img, type='rgb', size=img_size)
                img = img.to(device)

                out, latent_loss, _ = model(img)

                recon_loss = criterion(out, img)
                latent_loss = latent_loss.mean()

                loss_sum += recon_loss.item() * img.shape[0]
                loss_n += img.shape[0]

                lr = optimizer.param_groups[0]['lr']

                loader.set_description((
                    f'validation; mse: {recon_loss.item():.5f}; '
                    f'latent: {latent_loss.item():.3f}; avg mse: {loss_sum/loss_n:.5f}; '
                    f'lr: {lr:.5f}'))

    model.train()

    return loss_sum / loss_n
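
A hypothetical call site for this routine: val_loader, val_dataset, and the sizes are placeholder names, and the loader is assumed to yield (episode, label) pairs with episode shaped (batch, time, 3, H, W), which is what the episode[:, i] indexing above implies.

from torch.utils.data import DataLoader

# Minimal sketch, assuming val() as defined above; names and sizes are
# illustrative only.
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
best_mse = float('inf')
for epoch in range(n_epochs):
    avg_mse = val(epoch, val_loader, model, optimizer, scheduler, device,
                  img_size=64)
    if avg_mse < best_mse:
        best_mse = avg_mse
        torch.save(model.state_dict(), 'checkpoint/vqvae_best.pt')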
Code example #2
def val(epoch, loader, model, optimizer, scheduler, device, img_size):

    model.eval()
    loader = tqdm(loader)

    criterion = nn.BCELoss(reduction='none').to(device)
    bce_sum = 0
    bce_n = 0

    with torch.no_grad():
        for ep_idx, (episode, _) in enumerate(loader):
            for i in range(episode.shape[1]):

                img = episode[:, i]
                img = util.resize_img(img, type='seg', size=img_size)
                img = img.to(device)

                bce_weight = util.loss_weights(img).to(device)
                img = util.one_hot(img, device=device)

                out, latent_loss, _ = model(img)

                recon_loss = criterion(out, img)
                recon_loss = (recon_loss * bce_weight).mean()
                latent_loss = latent_loss.mean()

                bce_sum += recon_loss.item() * img.shape[0]
                bce_n += img.shape[0]

                lr = optimizer.param_groups[0]['lr']

                loader.set_description((
                    f'validation; bce: {recon_loss.item():.5f}; '
                    f'latent: {latent_loss.item():.3f}; avg bce: {bce_sum/bce_n:.5f}; '
                    f'lr: {lr:.5f}'))

    model.train()
    return bce_sum / bce_n
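
util.one_hot is project code that these snippets do not show; a minimal stand-in consistent with how it is used above (an integer mask of shape (B, 1, H, W) becoming a float one-hot tensor, with the 13 classes assumed in examples #5 and #6) might look like:

import torch
from torch.nn import functional as F

def one_hot(mask, n_classes=13, device='cpu'):
    # Stand-in sketch for util.one_hot: (B, 1, H, W) integer class ids
    # -> (B, n_classes, H, W) float one-hot, as the BCE loss above expects.
    mask = mask.squeeze(1).long().to(device)
    return F.one_hot(mask, num_classes=n_classes).permute(0, 3, 1, 2).float()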
Code example #3
def extract(loader, model, device, img_size, img_type, destination):

    # Put the model in eval mode so the extracted codes are deterministic
    model.eval()
    loader = tqdm(loader)

    for ep_idx, (episode, _) in enumerate(loader):
        codes = []  # per-frame latent codes, concatenated on the time axis below

        for i in range(episode.shape[1]):
            img = episode[:, i]
            img = util.resize_img(img, type=img_type, size=img_size)
            img = img.to(device)
            if img_type == 'seg':
                img = util.one_hot(img, device=device)

            with torch.no_grad():
                _, _, code = model.encode(img)

            codes.append(code.reshape(1, img.shape[0], -1))

        sequence = torch.cat(codes, 0).cpu().numpy()
        np.save(destination + f'/episode_{ep_idx}', sequence)
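
The saved files can be loaded straight back with NumPy; because the codes are concatenated on dim 0 frame by frame, each array comes back as (time, batch, code_dim). The concrete code_dim depends on the encoder's top grid; the 8x8 grid assumed in example #6 would give 64.

import numpy as np

# destination is the same directory that extract() wrote into.
codes = np.load(f'{destination}/episode_0.npy')
print(codes.shape)  # (time, batch, code_dim)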
Code example #4
def train(epoch, loader_rgb, loader_seg, model, optimizer, scheduler, device,
          img_size):
    loader_rgb = tqdm(loader_rgb)
    model.train()
    criterion = nn.MSELoss(reduction='none')
    criterion.to(device)

    latent_loss_weight = 0.25
    sample_size = 8

    loss_sum = 0
    loss_n = 0

    seg_iter = iter(loader_seg)
    for ep_idx, (episode, _) in enumerate(loader_rgb):
        episode_seg, _ = next(seg_iter)

        # Since we don't shuffle the dataloaders, we create and shuffle an order
        # to be used for both the rgb and segmentation episodes
        order = np.arange(episode.shape[1])
        np.random.shuffle(order)

        for i in range(episode.shape[1]):
            model.zero_grad()

            # Get and resize current batch of images
            img = episode[:, order[i]]
            img = util.resize_img(img, type='rgb', size=img_size)
            img = img.to(device)

            # Getting the bce_weights
            img_seg = episode_seg[:, order[i]]
            img_seg = util.resize_img(img_seg, type='seg',
                                      size=img_size).to(device)
            class_weights = util.seg_weights(img_seg).to(device)

            out, latent_loss, _ = model(img)

            # Calculate loss
            recon_loss = criterion(out, img)
            recon_loss = (recon_loss * class_weights).mean()
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss

            loss.backward()
            optimizer.step()

            recon_loss_item = recon_loss.item()
            latent_loss_item = latent_loss.item()
            loss_sum += recon_loss_item * img.shape[0]
            loss_n += img.shape[0]

            lr = optimizer.param_groups[0]['lr']

            loader_rgb.set_description((
                f'epoch: {epoch + 1}; mse: {recon_loss_item:.5f}; '
                f'latent: {latent_loss_item:.3f}; avg mse: {loss_sum / loss_n:.5f}; '
                f'lr: {lr:.5f}; '))

            # Save a grid of sample reconstructions once per episode (at step 100)
            if i == 100:
                model.eval()
                sample = img[:sample_size]

                with torch.no_grad():

                    out, _, _ = model(sample)

                utils.save_image(
                    torch.cat([sample, out], 0),
                    f'sample/{str(epoch + 1).zfill(4)}_{img_size}x{img_size}_GAN.png',
                    nrow=sample_size,
                )
                model.train()
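
Because train() pairs the RGB and segmentation episodes index-for-index via next(seg_iter) and reuses one shuffled order for both, the two loaders have to be built with shuffle=False over identically ordered datasets. A hypothetical setup:

from torch.utils.data import DataLoader

# Hypothetical setup; both datasets must enumerate the same episodes in the
# same order, since train() consumes them in lockstep.
loader_rgb = DataLoader(rgb_dataset, batch_size=16, shuffle=False)
loader_seg = DataLoader(seg_dataset, batch_size=16, shuffle=False)
train(epoch, loader_rgb, loader_seg, model, optimizer, scheduler, device,
      img_size=64)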
Code example #5
def train(epoch, loader, model, optimizer, scheduler, device, img_size):

    loader = tqdm(loader)
    model.train()

    criterion = nn.BCELoss(reduction='none')
    criterion.to(device)

    latent_loss_weight = 0.25
    sample_size = 8

    bce_sum = 0
    bce_n = 0

    for ep_idx, (episode, _) in enumerate(loader):
        for i in range(episode.shape[1]):

            model.zero_grad()
            # Get, resize and one-hot encode current batch of images
            img = episode[:, i]
            img = util.resize_img(img, type='seg', size=img_size)
            img = img.to(device)
            bce_weight = util.seg_weights(img, out_channel=13).to(device)
            img = util.one_hot(img, device=device)

            out, latent_loss, _ = model(img)

            recon_loss = criterion(out, img)
            recon_loss = (recon_loss * bce_weight).mean()
            latent_loss = latent_loss.mean()

            loss = recon_loss + latent_loss_weight * latent_loss

            loss.backward()
            optimizer.step()
            # Step the scheduler after the optimizer (PyTorch >= 1.1 ordering)
            if scheduler is not None:
                scheduler.step()

            recon_loss_item = recon_loss.item()
            latent_loss_item = latent_loss.item()
            bce_sum += recon_loss_item * img.shape[0]
            bce_n += img.shape[0]

            lr = optimizer.param_groups[0]['lr']

            loader.set_description((
                f'epoch: {epoch + 1}; bce: {recon_loss_item:.5f}; '
                f'latent: {latent_loss_item:.3f}; avg bce: {bce_sum/bce_n:.5f}; '
                f'lr: {lr:.5f}; '))

            if i % 100 == 0:
                model.eval()
                sample = img[:sample_size]

                with torch.no_grad():
                    out, _, _ = model(sample)
                # Convert one-hot semantic segmentation to RGB
                sample = util.seg_to_rgb(sample)
                out = util.seg_to_rgb(out)

                utils.save_image(
                    torch.cat([sample, out], 0),
                    f'sample/seg/{str(epoch + 1).zfill(4)}_{img_size}x{img_size}.png',
                    nrow=sample_size,
                )
                model.train()
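
util.seg_weights is likewise project code that is not shown; one plausible reading, given that its output is multiplied element-wise into the unreduced BCE above, is a per-pixel inverse-class-frequency weight:

import torch

def seg_weights(mask, out_channel=13):
    # Sketch of one plausible util.seg_weights: weight every pixel by the
    # inverse frequency of its class in the batch, so rare classes count more.
    # mask: (B, 1, H, W) integer class ids.
    flat = mask.long().view(-1)
    counts = torch.bincount(flat, minlength=out_channel).float().clamp(min=1)
    weights = flat.numel() / (out_channel * counts)
    return weights[mask.long()]  # (B, 1, H, W); broadcasts over the class dim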
Code example #6
def predict(loader, model_rnn, model_vqvae, args):

    model_rnn.eval()
    model_vqvae.eval()

    start = 75 * 3  # 0 for the original setting

    with torch.no_grad():
        for ep_idx, (episode, _) in enumerate(loader):
            sequence_top = None

            for i in range(start, (args.in_steps + args.steps) * 3 + start, 3):
                img = episode[:, i].to(args.device)
                img = util.resize_img(img,
                                      type=args.img_type,
                                      size=args.img_size)

                if args.img_type == 'seg':
                    img = util.one_hot(img, device=args.device)

                if i == start:
                    seq_in = img
                else:
                    seq_in = torch.cat((seq_in, img))

                _, _, top = model_vqvae.encode(img)

                if i == start:
                    sequence_top = top.reshape(1, img.shape[0], -1)
                else:
                    sequence_top = torch.cat(
                        (sequence_top, top.reshape(1, img.shape[0], -1)), 0)

            seq_len = sequence_top.shape[0]

            inputs_top = sequence_top.long()
            inputs_top = F.one_hot(inputs_top,
                                   num_classes=args.n_embed).float()
            inputs_top = inputs_top.view(-1, img.shape[0],
                                         args.n_embed * 64)  # 16

            # Forward pass through the rnn
            out, hidden = model_rnn(inputs_top[:args.in_steps])

            # Reshape, argmax and prepare for decoding image
            out = out[-1].unsqueeze(0)
            out_top = out.view(-1, args.n_embed, 64)  # 16
            out_top = torch.argmax(out_top, dim=1)
            out_top_seq = out_top.view(-1, 8, 8)  # 4,4

            out = out_top
            for t in range(args.steps - 1):
                # One-hot encode previous prediction
                out = out.long()
                out = F.one_hot(out, num_classes=args.n_embed).float()
                out = out.view(-1, img.shape[0], args.n_embed * 64)  # 16
                # Predict next frame
                out, hidden = model_rnn(out, hidden=hidden)
                # Argmax and save
                out = out.view(-1, args.n_embed, 64)  # 16
                out = torch.argmax(out, dim=1)
                out_top_seq = torch.cat((out_top_seq, out.view(-1, 8, 8)),
                                        0)  # 4,4

            decoded_samples = model_vqvae.decode_code(out_top_seq)  # old vqvae

            channels = 13 if args.img_type == 'seg' else 3
            seq_out = torch.zeros(args.in_steps, channels, args.img_size,
                                  args.img_size).to(args.device)
            seq_out = torch.cat((seq_out, decoded_samples), 0)

            sequence = torch.cat(
                (seq_in.to(args.device), seq_out.to(args.device)))
            sequence_rgb = util.seg_to_rgb(
                sequence) if args.img_type == 'seg' else sequence

            ########## save images and measure IoU ############
            if args.save_images:
                save_individual_images(sequence_rgb, ep_idx, args)
                utils.save_image(
                    sequence_rgb,
                    f'predictions/test_pred_{ep_idx}.png',
                    nrow=(args.in_steps + args.steps),
                )
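
A hypothetical invocation: the field names on args mirror exactly what predict() reads above, while the concrete values and checkpoint paths are placeholders.

import argparse
import torch

args = argparse.Namespace(device='cuda', img_type='seg', img_size=64,
                          in_steps=10, steps=15, n_embed=512,
                          save_images=True)
model_vqvae.load_state_dict(torch.load('checkpoint/vqvae.pt'))
model_rnn.load_state_dict(torch.load('checkpoint/rnn.pt'))
predict(test_loader, model_rnn, model_vqvae, args)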