Example #1
            gt_pos = p_gt[step_id].unsqueeze(0)
            if use_gpu:
                gt_pos = gt_pos.cuda()
            pred_pos = torch.cat([pred_pos, gt_pos[:, n_particle:]], 1)

            # gt_motion_norm (normalized): B x (n_p + n_s) x state_dim
            # pred_motion_norm (normalized): B x (n_p + n_s) x state_dim
            gt_motion = (p_gt[step_id] - p_gt[step_id - 1]).unsqueeze(0)
            if use_gpu:
                gt_motion = gt_motion.cuda()
            mean_d, std_d = model.stat[2:]
            gt_motion_norm = (gt_motion - mean_d) / std_d
            pred_motion_norm = torch.cat(
                [pred_motion_norm, gt_motion_norm[:, n_particle:]], 1)

            loss_cur = F.l1_loss(pred_motion_norm[:, :n_particle],
                                 gt_motion_norm[:, :n_particle])
            loss_cur_raw = F.l1_loss(pred_pos, gt_pos)

            loss += loss_cur
            loss_raw += loss_cur_raw
            loss_counter += 1

            # state_cur (unnormalized): B x n_his x (n_p + n_s) x state_dim
            state_cur = torch.cat([state_cur[:, 1:], pred_pos.unsqueeze(1)], 1)
            state_cur = state_cur.detach()[0]

            # record the prediction
            p_pred[step_id] = state_cur[-1].detach().cpu()
    # print(loss)
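The excerpt above compares motions in normalized space: prediction and ground truth are shifted and scaled by the same dataset statistics (mean_d, std_d) before F.l1_loss, so the loss is on a consistent scale across state dimensions. A minimal self-contained sketch of that pattern (shapes and statistics are hypothetical):

import torch
import torch.nn.functional as F

gt_motion = torch.randn(1, 8, 3)    # B x n_particles x state_dim
pred_motion = torch.randn(1, 8, 3)
mean_d, std_d = torch.zeros(3), torch.ones(3)  # per-dimension dataset stats

# normalizing both sides identically leaves the error ranking unchanged
# but puts the loss on a consistent scale
loss = F.l1_loss((pred_motion - mean_d) / std_d,
                 (gt_motion - mean_d) / std_d)
print(loss.item())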
Example #2
def train(args, train_loader, pose_net, optimizer, train_writer, num_sample):
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter(precision=4)
    # switch to train mode
    pose_net.train()

    end = time.time()
    for i, (tgt_img, ref_imgs, ref_poses, intrinsics, intrinsics_inv,
            tgt_depth, ref_depths, ref_noise_poses,
            initial_pose) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)
        tgt_img_var = Variable(tgt_img.cuda())
        ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
        ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
        ref_noise_poses_var = [
            Variable(pose.cuda()) for pose in ref_noise_poses
        ]
        initial_pose_var = Variable(initial_pose.cuda())

        ref_depths_var = [Variable(dep.cuda()) for dep in ref_depths]
        intrinsics_var = Variable(intrinsics.cuda())
        intrinsics_inv_var = Variable(intrinsics_inv.cuda())
        tgt_depth_var = Variable(tgt_depth.cuda())
        pose = torch.cat(ref_poses_var, 1)

        noise_pose = torch.cat(ref_noise_poses_var, 1)

        pose_norm = torch.norm(noise_pose[:, :, :3, 3], dim=-1,
                               keepdim=True)  # B x n x 1

        p_angle, p_trans, rot_c, trans_c = pose_net(tgt_img_var,
                                                    ref_imgs_var,
                                                    initial_pose_var,
                                                    noise_pose,
                                                    intrinsics_var,
                                                    intrinsics_inv_var,
                                                    tgt_depth_var,
                                                    ref_depths_var,
                                                    trans_norm=pose_norm)

        batch_size = p_angle.shape[0]
        p_angle_v = torch.sum(
            F.softmax(p_angle, dim=1).view(batch_size, -1, 1) * rot_c, dim=1)
        p_trans_v = torch.sum(
            F.softmax(p_trans, dim=1).view(batch_size, -1, 1) * trans_c, dim=1)
        p_matrix = Variable(torch.zeros((batch_size, 4, 4)).float()).cuda()
        p_matrix[:, 3, 3] = 1
        p_matrix[:, :3, :] = torch.cat(
            [angle2matrix(p_angle_v),
             p_trans_v.unsqueeze(-1)], dim=-1)  # 2*3*4
        loss = 0.
        loss_rot = 0.
        loss_trans = 0.
        for j in range(len(ref_imgs)):
            exp_pose = torch.matmul(inv(pose[:, j]), noise_pose[:, j])
            gt_angle = matrix2angle(exp_pose[:, :3, :3])
            gt_trans = exp_pose[:, :3, 3]

            loss_rot = F.l1_loss(p_angle_v, gt_angle) * 50
            loss_trans = F.l1_loss((p_trans_v / pose_norm[:, :, 0]),
                                   (gt_trans / pose_norm[:, :, 0])) * 50

            loss = loss + loss_trans + loss_rot

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

        if i > 0 and n_iter % args.print_freq == 0:
            train_writer.add_scalar('total_loss', loss.item(), n_iter)

        if n_iter > 0 and n_iter % 2000 == 0:
            save_checkpoint(args.save_path, {
                'epoch': n_iter + 1,
                'state_dict': pose_net.module.state_dict()
            }, n_iter)

        # record loss and EPE
        losses.update(loss.item(), batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        with open(args.save_path / args.log_full, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([loss.item()])
        # import pdb;pdb.set_trace()
        if i % args.print_freq == 0:

            print(
                'Train {}: Time {} Data {} Loss: {:.4f} rot: {:.4f} trans: {:.4f}'
                .format(i, batch_time, data_time, loss.item(), loss_rot.item(),
                        loss_trans.item()))
        n_iter += 1
    return losses.avg[0]
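loss_trans above divides both the predicted and ground-truth translations by the noise-pose norm, so the penalty is relative to the motion scale rather than absolute. A toy sketch of that normalization (values are hypothetical):

import torch
import torch.nn.functional as F

p_trans_v = torch.tensor([[0.9, 0.0, 0.0]])  # predicted translation
gt_trans = torch.tensor([[1.0, 0.0, 0.0]])   # ground-truth translation
pose_norm = torch.tensor([[2.0]])            # per-sample scale factor

loss_trans = F.l1_loss(p_trans_v / pose_norm, gt_trans / pose_norm) * 50
print(loss_trans.item())  # 50 * mean(|0.45 - 0.5|, 0, 0) ≈ 0.833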
Example #3
def l1_loss(pred, target):
    """Warpper of mse loss."""
    return F.l1_loss(pred, target, reduction='none')
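# Usage sketch (illustrative values, torch assumed imported): with
# reduction='none' the call returns the elementwise |pred - target| tensor,
# which the caller can weight or reduce later:
#   >>> l1_loss(torch.tensor([1.0, 2.0, 4.0]),
#   ...         torch.tensor([1.5, 2.0, 3.0]))
#   tensor([0.5000, 0.0000, 1.0000])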
def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
          backup_every: int, force_restart: bool, hparams):

    syn_dir = Path(syn_dir)
    models_dir = Path(models_dir)
    models_dir.mkdir(exist_ok=True)

    model_dir = models_dir.joinpath(run_id)
    plot_dir = model_dir.joinpath("plots")
    wav_dir = model_dir.joinpath("wavs")
    mel_output_dir = model_dir.joinpath("mel-spectrograms")
    meta_folder = model_dir.joinpath("metas")
    model_dir.mkdir(exist_ok=True)
    plot_dir.mkdir(exist_ok=True)
    wav_dir.mkdir(exist_ok=True)
    mel_output_dir.mkdir(exist_ok=True)
    meta_folder.mkdir(exist_ok=True)

    weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt")
    metadata_fpath = syn_dir.joinpath("train.txt")

    print("Checkpoint path: {}".format(weights_fpath))
    print("Loading training data from: {}".format(metadata_fpath))
    print("Using model: Tacotron")

    # Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)

    # From WaveRNN/train_tacotron.py
    if torch.cuda.is_available():
        device = torch.device("cuda")

        for session in hparams.tts_schedule:
            _, _, _, batch_size = session
            if batch_size % torch.cuda.device_count() != 0:
                raise ValueError(
                    "`batch_size` must be evenly divisible by n_gpus!")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Instantiate Tacotron Model
    print("\nInitialising Tacotron Model...\n")
    model = Tacotron(
        embed_dims=hparams.tts_embed_dims,
        num_chars=len(symbols),
        encoder_dims=hparams.tts_encoder_dims,
        decoder_dims=hparams.tts_decoder_dims,
        n_mels=hparams.num_mels,
        fft_bins=hparams.num_mels,
        postnet_dims=hparams.tts_postnet_dims,
        encoder_K=hparams.tts_encoder_K,
        lstm_dims=hparams.tts_lstm_dims,
        postnet_K=hparams.tts_postnet_K,
        num_highways=hparams.tts_num_highways,
        dropout=hparams.tts_dropout,
        stop_threshold=hparams.tts_stop_threshold,
        speaker_embedding_size=hparams.speaker_embedding_size).to(device)

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())

    # Load the weights
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of Tacotron from scratch\n")
        model.save(weights_fpath)

        # Embeddings metadata
        char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
        with open(char_embedding_fpath, "w", encoding="utf-8") as f:
            for symbol in symbols:
                if symbol == " ":
                    symbol = "\\s"  # For visual purposes, swap space with \s

                f.write("{}\n".format(symbol))

    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("Tacotron weights loaded from step %d" % model.step)

    # Initialize the dataset
    metadata_fpath = syn_dir.joinpath("train.txt")
    mel_dir = syn_dir.joinpath("mels")
    embed_dir = syn_dir.joinpath("embeds")
    dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
    test_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True)

    for i, session in enumerate(hparams.tts_schedule):
        current_step = model.get_step()

        r, lr, max_step, batch_size = session

        training_steps = max_step - current_step

        # Do we need to change to the next session?
        if current_step >= max_step:
            # Are there no further sessions than the current one?
            if i == len(hparams.tts_schedule) - 1:
                # We have completed training. Save the model and exit
                model.save(weights_fpath, optimizer)
                break
            else:
                # There is a following session, go to it
                continue

        model.r = r

        # Begin the training
        simple_table([(f"Steps with r={r}",
                       str(training_steps // 1000) + "k Steps"),
                      ("Batch Size", batch_size), ("Learning Rate", lr),
                      ("Outputs/Step (r)", model.r)])

        for p in optimizer.param_groups:
            p["lr"] = lr

        data_loader = DataLoader(
            dataset,
            collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
            batch_size=batch_size,
            num_workers=0,
            shuffle=True,
            pin_memory=True)

        total_iters = len(dataset)
        steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
        epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)

        for epoch in range(1, epochs + 1):
            for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
                start_time = time.time()

                # Generate stop tokens for training
                stop = torch.ones(mels.shape[0], mels.shape[2])
                for j, k in enumerate(idx):
                    stop[j, :int(dataset.metadata[k][4]) - 1] = 0

                texts = texts.to(device)
                mels = mels.to(device)
                embeds = embeds.to(device)
                stop = stop.to(device)

                # Forward pass
                # Parallelize model onto GPUS using workaround due to python bug
                if device.type == "cuda" and torch.cuda.device_count() > 1:
                    m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(
                        model, texts, mels, embeds)
                else:
                    m1_hat, m2_hat, attention, stop_pred = model(
                        texts, mels, embeds)

                # Backward pass
                m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
                m2_loss = F.mse_loss(m2_hat, mels)
                stop_loss = F.binary_cross_entropy(stop_pred, stop)

                loss = m1_loss + m2_loss + stop_loss

                optimizer.zero_grad()
                loss.backward()

                if hparams.tts_clip_grad_norm is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), hparams.tts_clip_grad_norm)
                    if np.isnan(grad_norm.cpu()):
                        print("grad_norm was NaN!")

                optimizer.step()

                time_window.append(time.time() - start_time)
                loss_window.append(loss.item())

                step = model.get_step()
                k = step // 1000

                msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
                stream(msg)

                # Backup or save model as appropriate
                if backup_every != 0 and step % backup_every == 0:
                    backup_fpath = Path("{}/{}_{}k.pt".format(
                        str(weights_fpath.parent), run_id, k))
                    model.save(backup_fpath, optimizer)

                if save_every != 0 and step % save_every == 0:
                    # Must save latest optimizer state to ensure that resuming training
                    # doesn't produce artifacts
                    model.save(weights_fpath, optimizer)

                # Evaluate model to generate samples
                epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch  # If epoch is done
                step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0  # Every N steps
                if epoch_eval or step_eval:
                    for sample_idx in range(hparams.tts_eval_num_samples):
                        # At most, generate samples equal to number in the batch
                        if sample_idx + 1 <= len(texts):
                            # Remove padding from mels using frame length in metadata
                            mel_length = int(
                                dataset.metadata[idx[sample_idx]][4])
                            mel_prediction = np_now(
                                m2_hat[sample_idx]).T[:mel_length]
                            target_spectrogram = np_now(
                                mels[sample_idx]).T[:mel_length]
                            attention_len = mel_length // model.r

                            eval_model(attention=np_now(
                                attention[sample_idx][:, :attention_len]),
                                       mel_prediction=mel_prediction,
                                       target_spectrogram=target_spectrogram,
                                       input_seq=np_now(texts[sample_idx]),
                                       step=step,
                                       plot_dir=plot_dir,
                                       mel_output_dir=mel_output_dir,
                                       wav_dir=wav_dir,
                                       sample_num=sample_idx + 1,
                                       loss=loss,
                                       hparams=hparams)

                # Break out of loop to update training schedule
                if step >= max_step:
                    break

            # Add line break after every epoch
            print("")
Example #5
 def forward(self, input, target):
     return F.l1_loss(input, target, reduction=self.reduction)
Example #6
def train():
    # v1, all ds
    # v3, stcmds ds alone

    stcmds_ds = dataset.new_stcmds_dataset(root=hp.stcmds_data_root, mel_feature_root=hp.mel_feature_root)
    # aishell_ds = dataset.new_aishell_dataset(root=hp.aishell_data_root, mel_feature_root=hp.mel_feature_root)
    # aidatatang_ds = dataset.new_aidatatang_dataset(root=hp.aidatatang_data_root, mel_feature_root=hp.mel_feature_root)
    # primewords_ds = dataset.new_primewords_dataset(root=hp.primewords_data_root, mel_feature_root=hp.mel_feature_root)
    # datasets = [stcmds_ds, aishell_ds, aidatatang_ds, primewords_ds]
    datasets = [stcmds_ds]
    mds = dataset.MultiAudioDataset(datasets)
    random.shuffle(mds.speakers)
    train_speakers = mds.speakers
    # eval_speakers = mds.speakers[-100:]
    
    ds = dataset.VocoderDataset(train_speakers,
                        utterances_per_speaker=1,
                        seq_len=hp.vocoder_seq_len)
    loader = torch.utils.data.DataLoader(ds,
                                        batch_size=hp.vocoder_batch_size,
                                        shuffle=True,
                                        num_workers=6)

    netG = Generator(hp.num_mels, hp.vocoder_ngf, hp.vocoder_n_residual_layers).cuda()
    netD = Discriminator(hp.vocoder_num_D, hp.vocoder_ndf, hp.vocoder_n_layers_D, hp.vocoder_downsamp_factor).cuda()
    fft = Audio2Mel(n_fft=hp.n_fft,
                    hop_length=hp.hop_length,
                    win_length=hp.win_length,
                    sampling_rate=hp.sample_rate,
                    n_mel_channels=hp.num_mels,
                    mel_fmin=hp.fmin,
                    mel_fmax=hp.fmax,
                    min_level_db=hp.min_level_db).cuda()

    optG = torch.optim.Adam(netG.parameters(), lr=hp.vocoder_G_lr, betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=hp.vocoder_D_lr, betas=(0.5, 0.9))

    total_steps = 0

    ckpts = sorted(list(Path(hp.vocoder_save_dir).glob('*.pt')))
    if len(ckpts) > 0:
        latest_ckpt_path = ckpts[-1]
        ckpt = torch.load(latest_ckpt_path)
        if ckpt:
            logging.info(f'loading vocoder ckpt {latest_ckpt_path}')
            netG.load_state_dict(ckpt['netG_state_dict'])
            netD.load_state_dict(ckpt['netD_state_dict'])
            optG.load_state_dict(ckpt['optG_state_dict'])
            optD.load_state_dict(ckpt['optD_state_dict'])
            total_steps = ckpt['total_steps']

    while True:
        if total_steps >= hp.vocoder_train_steps:
            break

        for segments in loader:
            if total_steps >= hp.vocoder_train_steps:
                break

            x_t = segments.cuda()
            s_t = fft(x_t).detach()
            # print(f's_t.shape {s_t.shape}')
            x_pred_t = netG(s_t.cuda())
            # print(f'x_pred_t {x_pred_t.shape}')

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.cuda().detach())
            D_real = netD(x_t.cuda())

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()

            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.cuda())

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (hp.vocoder_n_layers_D + 1)  # e.g. 0.8 for 4 layers
            D_weights = 1.0 / hp.vocoder_num_D  # e.g. 0.33333 for 3 scales
            wt = D_weights * feat_weights  # e.g. 0.26667
            for i in range(hp.vocoder_num_D):
                for j in range(len(D_fake[i]) - 1):
                    # print(f'i,j {i},{j} {D_fake[i][j]}')  # very verbose debug output
                    loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + 10. * loss_feat).backward()
            optG.step()

            total_steps += 1

            if (total_steps+1) % hp.vocoder_train_print_interval == 0:
                logging.info(f'vocoder step {total_steps+1} loss discriminator {loss_D.item():.3f} generator {loss_G.item():.3f} FM {loss_feat.item():.3f} recon {s_error:.3f}')
            if (total_steps+1) % hp.vocoder_save_interval == 0:
                if not Path(hp.vocoder_save_dir).exists():
                    Path(hp.vocoder_save_dir).mkdir()
                save_path = Path(hp.vocoder_save_dir) / f'{total_steps+1:012d}.pt'
                logging.info(f'saving vocoder ckpt {save_path}')
                torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optG_state_dict': optG.state_dict(),
                    'optD_state_dict': optD.state_dict(),
                    'total_steps': total_steps
                }, save_path)

                # remove old ckpts
                ckpts = sorted(list(Path(hp.vocoder_save_dir).glob('*.pt')))
                if len(ckpts) > hp.vocoder_max_ckpts:
                    for ckpt in ckpts[:-hp.vocoder_max_ckpts]:
                        Path(ckpt).unlink()
                        logging.info(f'ckpt {ckpt} removed')
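The inner double loop above is a feature-matching loss: an L1 penalty between discriminator activations for generated and (detached) real audio, weighted across scales and layers. A toy sketch of the same weighting scheme with hypothetical feature lists:

import torch
import torch.nn.functional as F

num_D, n_layers_D = 3, 4
# one list of feature maps per discriminator scale; the last entry is the score map
fake_feats = [[torch.randn(2, 16, 100) for _ in range(n_layers_D + 2)]
              for _ in range(num_D)]
real_feats = [[torch.randn(2, 16, 100) for _ in range(n_layers_D + 2)]
              for _ in range(num_D)]

wt = (1.0 / num_D) * (4.0 / (n_layers_D + 1))
loss_feat = 0
for i in range(num_D):
    for j in range(len(fake_feats[i]) - 1):  # skip the final score map
        loss_feat += wt * F.l1_loss(fake_feats[i][j], real_feats[i][j].detach())
print(loss_feat.item())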
Example #7
 def forward(self, input):
     y = self.model(input.x)
     prefix = 'train' if self.training else 'val'
     loss = F.l1_loss(y, input.t)
     ppe.reporting.report({prefix + '/loss': loss})
     return Output(y, loss, input.v)
Example #8
def main():
    args = parse_args()

    root = Path(args.save_path)
    load_root = Path(args.load_path) if args.load_path else None
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    writer = SummaryWriter(str(root))

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args.n_mel_channels, args.ngf,
                     args.n_residual_layers).cuda()
    netD = Discriminator(args.num_D, args.ndf, args.n_layers_D,
                         args.downsamp_factor).cuda()
    fft = Audio2Mel(n_mel_channels=args.n_mel_channels).cuda()

    print(netG)
    print(netD)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.AdamW(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optD = torch.optim.AdamW(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

    if load_root and load_root.exists():
        netG.load_state_dict(torch.load(load_root / "netG.pt"))
        optG.load_state_dict(torch.load(load_root / "optG.pt"))
        netD.load_state_dict(torch.load(load_root / "netD.pt"))
        optD.load_state_dict(torch.load(load_root / "optD.pt"))

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(Path(args.data_path) / "train_files.txt",
                             args.seq_len,
                             sampling_rate=22050)
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        22050 * 4,
        sampling_rate=22050,
        augment=False,
    )

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.cuda()
        s_t = fft(x_t).detach()

        test_voc.append(s_t.cuda())
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), 22050, audio)
        writer.add_audio("original/sample_%d.wav" % i,
                         audio,
                         0,
                         sample_rate=22050)

        if i == args.n_test_samples - 1:
            break

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    steps = 0
    for epoch in range(1, args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.cuda()
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.cuda())

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.cuda().detach())
            D_real = netD(x_t.cuda())

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()

            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.cuda())

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j],
                                                D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################
            costs.append(
                [loss_D.item(),
                 loss_G.item(),
                 loss_feat.item(), s_error])

            writer.add_scalar("loss/discriminator", costs[-1][0], steps)
            writer.add_scalar("loss/generator", costs[-1][1], steps)
            writer.add_scalar("loss/feature_matching", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)
            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i), 22050,
                                    pred_audio)
                        writer.add_audio(
                            "generated/sample_%d.wav" % i,
                            pred_audio,
                            epoch,
                            sample_rate=22050,
                        )

                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")

                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".
                      format(
                          epoch,
                          iterno,
                          len(train_loader),
                          1000 * (time.time() - start) / args.log_interval,
                          np.asarray(costs).mean(0),
                      ))
                costs = []
                start = time.time()
Example #9
        # best_pesq = checkpoint['pesq']
        best_pesq = 0.0
    else:
        start_epoch = 0
        best_pesq = 0.0
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)

    # add graph to tensorboard
    if args.add_graph:
        dummy = torch.randn(16, 1, args.hop_length * 16).to(device)
        writer.add_graph(net, dummy)

    # define loss
    per_loss = PerceptualLoss(model_type=args.model_type)
    per_loss = per_loss.to(device)
    criterion = lambda y_hat, y: per_loss(y_hat, y) + F.l1_loss(y_hat, y)

    # iteration start
    for epoch in range(start_epoch, start_epoch + args.num_epochs, 1):
        # ------------- training -------------
        net.train()
        pbar = tqdm(train_dataloader,
                    bar_format='{l_bar}%s{bar}%s{r_bar}' %
                    (Fore.BLUE, Fore.RESET))
        pbar.set_description(f'Epoch {epoch + 1}')
        total_loss = 0.0
        if args.log_grad_norm:
            total_norm = 0.0
        net.zero_grad()
        for i, (n, c) in enumerate(pbar):
            n, c = n.to(device), c.to(device)
Example #10
    def train(self):
        # Set data loader.
        data_loader = self.vcc_loader

        # Print logs in specified order
        keys = ['G/loss_id', 'G/loss_id_psnt', 'G/loss_cd']

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch data.
            try:
                x_real, emb_org = next(data_iter)
            except (NameError, StopIteration):
                # data_iter does not exist yet on the first pass, or is exhausted
                data_iter = iter(data_loader)
                x_real, emb_org = next(data_iter)

            x_real = x_real.to(self.device)
            emb_org = emb_org.to(self.device)

            # =================================================================================== #
            #                               2. Train the generator                                #
            # =================================================================================== #

            self.G = self.G.train()

            # Identity mapping loss
            x_identic, x_identic_psnt, code_real = self.G(
                x_real, emb_org, emb_org)
            g_loss_id = F.mse_loss(x_real, x_identic)
            g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt)

            # Code semantic loss.
            code_reconst = self.G(x_identic_psnt, emb_org, None)
            g_loss_cd = F.l1_loss(code_real, code_reconst)

            # Backward and optimize.
            g_loss = g_loss_id + g_loss_id_psnt + self.lambda_cd * g_loss_cd
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # Logging.
            loss = {}
            loss['G/loss_id'] = g_loss_id.item()
            loss['G/loss_id_psnt'] = g_loss_id_psnt.item()
            loss['G/loss_cd'] = g_loss_cd.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.4f}".format(tag, loss[tag])
                print(log)

            # save model
            if (i + 1) % 1000 == 0:
                torch.save({"model": self.G.state_dict()}, "./autovc")
Example #11
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        self.model.train()
        self.netG.train()
        self.netD.train()

        for batch_idx, (datas, datat) in tqdm.tqdm(
            enumerate(zip(self.train_loader, self.target_loader)),
            total=min(len(self.target_loader), len(self.train_loader)),
            desc='Train epoch = %d' % self.epoch, ncols=80, leave=False):

            data_source, labels_source = datas
            data_target, __ = datat
            data_source_forD = torch.zeros((data_source.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))            
            data_target_forD = torch.zeros((data_target.size()[0], 3, self.image_size_forD[1], self.image_size_forD[0]))
            
            # We pass the unnormalized data to the discriminator. So, the GANs produce images without data normalization
            for i in range(data_source.size()[0]):
                data_source_forD[i] = self.train_loader.dataset.transform_forD(data_source[i], self.image_size_forD, resize=False, mean_add=True)
                data_target_forD[i] = self.train_loader.dataset.transform_forD(data_target[i], self.image_size_forD, resize=False, mean_add=True)

            iteration = batch_idx + self.epoch * min(len(self.train_loader), len(self.target_loader))
            self.iteration = iteration

            if self.cuda:
                data_source, labels_source = data_source.cuda(), labels_source.cuda()
                data_target = data_target.cuda()
                data_source_forD = data_source_forD.cuda()
                data_target_forD = data_target_forD.cuda()
            
            data_source, labels_source = Variable(data_source), Variable(labels_source)
            data_target = Variable(data_target)
            data_source_forD = Variable(data_source_forD)
            data_target_forD = Variable(data_target_forD)



            # Source domain 
            score, fc7, pool4, pool3 = self.model(data_source)
            outG_src = self.netG(fc7, pool4, pool3)
            outD_src_fake_s, outD_src_fake_c = self.netD(outG_src)
            outD_src_real_s, outD_src_real_c = self.netD(data_source_forD)
            
            # target domain
            tscore, tfc7, tpool4, tpool3= self.model(data_target)
            outG_tgt = self.netG(tfc7, tpool4, tpool3)
            outD_tgt_real_s, outD_tgt_real_c = self.netD(data_target_forD)
            outD_tgt_fake_s, outD_tgt_fake_c = self.netD(outG_tgt)

            # Creating labels for D. We need two sets of labels since our model is a ACGAN style framework.
            # (1) Labels for the classsifier branch. This will be a downsampled version of original segmentation labels
            # (2) Domain lables for classifying source real, source fake, target real and target fake
            
            # Labels for classifier branch 
            Dout_sz = outD_src_real_s.size()
            label_forD = torch.zeros((outD_tgt_fake_c.size()[0], outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))
            for i in range(label_forD.size()[0]):
                label_forD[i] = self.train_loader.dataset.transform_label_forD(labels_source[i], (outD_tgt_fake_c.size()[2], outD_tgt_fake_c.size()[3]))
            if self.cuda:
                label_forD = label_forD.cuda()
            label_forD = Variable(label_forD.long())

            # Domain labels
            domain_labels_src_real = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()
            domain_labels_src_fake = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()+1
            domain_labels_tgt_real = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()+2
            domain_labels_tgt_fake = torch.LongTensor(Dout_sz[0],Dout_sz[2],Dout_sz[3]).zero_()+3

            domain_labels_src_real = Variable(domain_labels_src_real.cuda())
            domain_labels_src_fake = Variable(domain_labels_src_fake.cuda())
            domain_labels_tgt_real = Variable(domain_labels_tgt_real.cuda())
            domain_labels_tgt_fake = Variable(domain_labels_tgt_fake.cuda())

            
            # Updates.
            # There are three sets of updates - (1) Discriminator, (2) Generator and (3) F network
            
            # (1) Discriminator updates
            lossD_src_real_s = cross_entropy2d(outD_src_real_s, domain_labels_src_real, size_average=self.size_average)
            lossD_src_fake_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_fake, size_average=self.size_average)
            lossD_src_real_c = cross_entropy2d(outD_src_real_c, label_forD, size_average=self.size_average)
            lossD_tgt_real = cross_entropy2d(outD_tgt_real_s, domain_labels_tgt_real, size_average=self.size_average)
            lossD_tgt_fake = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_fake, size_average=self.size_average)           
            
            self.optimD.zero_grad()            
            lossD = lossD_src_real_s + lossD_src_fake_s + lossD_src_real_c + lossD_tgt_real + lossD_tgt_fake
            lossD /= len(data_source)
            lossD.backward(retain_graph=True)
            self.optimD.step()
        
            
            # (2) Generator updates
            self.optimG.zero_grad()            
            lossG_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_src_real,size_average=self.size_average)
            lossG_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD,size_average=self.size_average)
            lossG_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_tgt_real,size_average=self.size_average)
            lossG_src_mse = F.l1_loss(outG_src,data_source_forD)
            lossG_tgt_mse = F.l1_loss(outG_tgt,data_target_forD)

            lossG = lossG_src_adv_c + 0.1*(lossG_src_adv_s+ lossG_tgt_adv_s) + self.l1_weight * (lossG_src_mse + lossG_tgt_mse)
            lossG /= len(data_source)
            lossG.backward(retain_graph=True)
            self.optimG.step()

            # (3) F network updates 
            self.optim.zero_grad()            
            lossC = cross_entropy2d(score, labels_source,size_average=self.size_average)
            lossF_src_adv_s = cross_entropy2d(outD_src_fake_s, domain_labels_tgt_real,size_average=self.size_average)
            lossF_tgt_adv_s = cross_entropy2d(outD_tgt_fake_s, domain_labels_src_real,size_average=self.size_average)
            lossF_src_adv_c = cross_entropy2d(outD_src_fake_c, label_forD,size_average=self.size_average)
            
            lossF = lossC + self.adv_weight*(lossF_src_adv_s + lossF_tgt_adv_s) + self.c_weight*lossF_src_adv_c
            lossF /= len(data_source)
            lossF.backward()
            self.optim.step()
            
            if np.isnan(lossD.item()):
                raise ValueError('lossD is nan while training')
            if np.isnan(lossG.item()):
                raise ValueError('lossG is nan while training')
            if np.isnan(lossF.item()):
                raise ValueError('lossF is nan while training')
           
            # Computing metrics for logging
            metrics = []
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = labels_source.data.cpu().numpy()
            for lt, lp in zip(lbl_true, lbl_pred):
                acc, acc_cls, mean_iu, fwavacc = \
                    torchfcn.utils.label_accuracy_score(
                        [lt], [lp], n_class=self.n_class)
                metrics.append((acc, acc_cls, mean_iu, fwavacc))
            metrics = np.mean(metrics, axis=0)

            # Logging
            with open(osp.join(self.out, 'log.csv'), 'a') as f:
                elapsed_time = (
                    datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
                    self.timestamp_start).total_seconds()
                log = [self.epoch, self.iteration] + [lossF.item()] + \
                    metrics.tolist() + [''] * 5 + [elapsed_time]
                log = map(str, log)
                f.write(','.join(log) + '\n')

            if self.iteration >= self.max_iter:
                break
            
            # Validating periodically
            if self.iteration % self.interval_validate == 0 and self.iteration > 0:
                out_recon = osp.join(self.out, 'visualization_viz')
                if not osp.exists(out_recon):
                    os.makedirs(out_recon)
                generations = []

                # Saving generated source and target images
                source_img = self.val_loader.dataset.untransform(data_source.data.cpu().numpy().squeeze())
                target_img = self.val_loader.dataset.untransform(data_target.data.cpu().numpy().squeeze())
                outG_src_ = (outG_src)*255.0
                outG_tgt_ = (outG_tgt)*255.0
                outG_src_ = outG_src_.data.cpu().numpy().squeeze().transpose((1,2,0))[:,:,::-1].astype(np.uint8)
                outG_tgt_ = outG_tgt_.data.cpu().numpy().squeeze().transpose((1,2,0))[:,:,::-1].astype(np.uint8)

                generations.append(source_img)
                generations.append(outG_src_)
                generations.append(target_img)
                generations.append(outG_tgt_)
                out_file = osp.join(out_recon, 'iter%012d_src_target_recon.png' % self.iteration)
                scipy.misc.imsave(out_file, fcn.utils.get_tile_image(generations))

                # Validation
                self.validate()
                self.model.train()
                self.netG.train()
Example #12
 def consistency_loss(self, feat, feat_proj):
     # -- l1 loss
     return F.l1_loss(feat, feat_proj, reduction="none").mean(dim=(1, 2, 3))
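     # Note: reduction="none" keeps the elementwise losses, and .mean(dim=(1, 2, 3))
     # then collapses the channel/height/width axes, yielding one scalar per
     # sample: input (B, C, H, W) -> output shape (B,).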
Example #13
File: loss.py Project: teboli/CPCR
 def forward(self, inputs, target):
     return F.l1_loss(inputs[-1], target)
Example #14
    def generator_step(self, step_vars, model_out, critic_pred=None):
        losses = {}

        imgs, objs, boxes, obj_to_img, predicates, masks = step_vars
        imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out

        total_loss = torch.zeros(1).to(imgs)
        skip_pixel_loss = (boxes is None)
        # Pixel Loss
        l1_pixel_weight = self.l1_pixel_loss_weight
        if skip_pixel_loss:
            l1_pixel_weight = 0

        l1_pixel_loss = F.l1_loss(imgs_pred, imgs)

        total_loss = self.add_loss(total_loss, l1_pixel_loss, losses,
                                   'L1_pixel_loss', l1_pixel_weight)

        # Box Loss
        loss_bbox = F.mse_loss(boxes_pred, boxes)
        total_loss = self.add_loss(total_loss, loss_bbox, losses, 'bbox_pred',
                                   self.bbox_pred_loss_weight)

        if self.predicate_pred_loss_weight > 0:
            loss_predicate = F.cross_entropy(predicate_scores, predicates)
            total_loss = self.add_loss(total_loss, loss_predicate, losses,
                                       'predicate_pred',
                                       self.predicate_pred_loss_weight)

        if self.mask_loss_weight > 0 and masks is not None and masks_pred is not None:
            # Mask Loss
            mask_loss = F.binary_cross_entropy(masks_pred, masks.float())
            total_loss = self.add_loss(total_loss, mask_loss, losses,
                                       'mask_loss', self.mask_loss_weight)

        if self.obj_discriminator is not None:
            # OBJ AC Loss: Classification of Objects
            scores_fake, ac_loss = self.obj_discriminator(
                imgs_pred, objs, boxes, obj_to_img)
            total_loss = self.add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                       self.ac_loss_weight)

            # OBJ GAN Loss: Real vs Fake
            weight = self.discriminator_loss_weight * self.d_obj_weight
            total_loss = self.add_loss(total_loss,
                                       self.gan_g_loss(scores_fake), losses,
                                       'g_gan_obj_loss', weight)

        if self.img_discriminator is not None:
            # IMG GAN Loss: Patches should be realistic
            scores_fake = self.img_discriminator(imgs_pred)
            weight = self.discriminator_loss_weight * self.d_img_weight
            total_loss = self.add_loss(total_loss,
                                       self.gan_g_loss(scores_fake), losses,
                                       'g_gan_img_loss', weight)

        if critic_pred is not None:
            # critic pred: (fake local, real local, fake global, real global)
            fake_local_pred, _, fake_global_pred, _ = critic_pred

            # Local Patch Loss
            local_loss = self.critic_g_loss(fake_local_pred)

            # Global Loss
            global_loss = self.critic_g_loss(fake_global_pred)

            critic_loss = self.critic_global_weight * global_loss + local_loss
            total_loss = self.add_loss(total_loss,
                                       self.critic_g_weight * critic_loss,
                                       losses, 'g_critic_loss')

        losses['total_loss'] = total_loss.item()

        self.reset_grad()
        total_loss.backward(retain_graph=True)
        # torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 5)
        self.gen_optimizer.step()

        return total_loss, losses
Example #15
            data_loader.buff_status[image_buff_read_index] = 'empty'

            image_buff_read_index = image_buff_read_index + 1
            if image_buff_read_index >= data_loader.image_buffer_size:
                image_buff_read_index = 0

        if is_gpu_mode:
            inputs = Variable(torch.from_numpy(input_img).float().cuda())
        else:
            inputs = Variable(torch.from_numpy(input_img).float())

        outputs = autoencoder_model(inputs)

        # l1-loss between real and fake
        l1_loss = F.l1_loss(outputs, inputs)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model parameters
        l1_loss.backward(retain_graph=False)

        # Calling the step function on an Optimizer makes an update to its parameters
        optimizer.step()

        if i % 50 == 0:
            print('-----------------------------------------------')
            print('iterations =', i)
Example #16
def reconstruction_loss(data, recon_data, distribution="bernoulli"):
    """
    Calculates the per image reconstruction loss for a batch of data. I.e. negative
    log likelihood.

    Parameters
    ----------
    data : torch.Tensor
        Input data (e.g. batch of images). Shape : (batch_size, n_chan,
        height, width).

    recon_data : torch.Tensor
        Reconstructed data. Shape : (batch_size, n_chan, height, width).

    distribution : {"bernoulli", "gaussian", "laplace"}
        Distribution of the likelihood on the each pixel. Implicitely defines the
        loss Bernoulli corresponds to a binary cross entropy (bse) loss and is the
        most commonly used. It has the issue that it doesn't penalize the same
        way (0.1,0.2) and (0.4,0.5), which might not be optimal. Gaussian
        distribution corresponds to MSE, and is sometimes used, but hard to train
        ecause it ends up focusing only a few pixels that are very wrong. Laplace
        distribution corresponds to L1 solves partially the issue of MSE.

    storer : dict
        Dictionary in which to store important variables for vizualisation.

    Returns
    -------
    loss : torch.Tensor
        Per image cross entropy (i.e. normalized per batch but not pixel and
        channel)
    """
    RECON_DIST = ["bernoulli", "gaussian", "laplace"]
    batch_size, n_chan, height, width = recon_data.size()
    is_colored = n_chan == 3

    if recon_data.min().item() < 0:
        print('RE')
        print(recon_data.min())
    if data.min().item() < 0:
        print('GT')
        print(data.min())

    if distribution == "bernoulli":
        loss = F.binary_cross_entropy(recon_data, data, reduction="sum")
        # try:
        #     loss = F.binary_cross_entropy(recon_data, data, reduction="sum")
        # except RuntimeError:
        #     print(index)
        #     from PIL import Image
        #     print(RuntimeError)
        #
        #     print(recon_data.shape)
        #     print(data.shape)
        #
        #     aa = np.array(recon_data.detach().cpu())
        #     bb = np.array(data.detach().cpu())
        #     for i in range(aa.shape[0]):
        #         re = Image.fromarray(aa[i].squeeze(0), mode='L')
        #         data = Image.fromarray(bb[i].squeeze(0), mode='L')
        #         re.show()
        #         data.show()



    elif distribution == "gaussian":
        # loss in [0,255] space but normalized by 255 to not be too big
        loss = F.mse_loss(recon_data * 255, data * 255, reduction="sum") / 255
    elif distribution == "laplace":
        # loss in [0,255] space but normalized by 255 to not be too big but
        # multiply by 255 and divide 255, is the same as not doing anything for L1
        loss = F.l1_loss(recon_data, data, reduction="sum")
        loss = loss * 3  # empirical value to give values similar to bernoulli => use same hyperparams
        loss = loss * (loss != 0)  # masking to avoid nan
    else:
        assert distribution not in RECON_DIST
        raise ValueError("Unkown distribution: {}".format(distribution))

    loss = loss / batch_size

    return loss
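A quick usage sketch of the function above (hypothetical tensors; assumes torch.nn.functional is imported as F, as the surrounding module requires). All three likelihood choices expect data and reconstructions in [0, 1] and return a scalar normalized by batch size only:

import torch

data = torch.rand(8, 1, 32, 32)   # batch of inputs in [0, 1]
recon = torch.rand(8, 1, 32, 32)  # decoder outputs in [0, 1]

for dist in ("bernoulli", "gaussian", "laplace"):
    print(dist, reconstruction_loss(data, recon, distribution=dist).item())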
Example #17
        # feedforward the inputs. generator
        outputs_gen = gen_model(inputs_with_mv)

        # feedforward the (input, answer) pairs. discriminator
        output_disc_real = disc_model(torch.cat((inputs_with_mv, answers), 1))
        output_disc_real_with_fake_vec = disc_model(torch.cat((inputs_with_fake_mv, answers), 1))
        output_disc_fake = disc_model(torch.cat((inputs_with_mv, outputs_gen), 1))
        output_disc_fake_with_fake_vec = disc_model(torch.cat((inputs_with_fake_mv, outputs_gen), 1))

        # loss functions

        # lsgan loss for the discriminator
        loss_disc_total_lsgan = 0.5 * (torch.mean((output_disc_real - 1)**2) + torch.mean(output_disc_fake**2) + 2 * torch.mean(output_disc_real_with_fake_vec**2) + 2 * torch.mean(output_disc_fake_with_fake_vec**2))

        # l1-loss between real and fake
        l1_loss = F.l1_loss(outputs_gen, answers)

        # vanilla gan loss for the generator
        #loss_gen_vanilla = F.binary_cross_entropy(output_disc_fake, ones_label)

        # lsgan loss for the generator
        loss_gen_lsgan = 0.5 * torch.mean((output_disc_fake - 1)**2)

        loss_gen_total_lsgan = 5 * loss_gen_lsgan + 0.01 * l1_loss

        # loss_disc_total = -torch.mean(torch.log(output_disc_real) + torch.log(1. - output_disc_fake))
        # loss_gen = -torch.mean(torch.log(output_disc_fake))

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
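The generator objective above mixes an adversarial LSGAN term with a small L1 reconstruction weight (5 vs. 0.01), so the L1 term acts as a regularizer rather than the main training signal. A sketch with hypothetical tensors:

import torch
import torch.nn.functional as F

output_disc_fake = torch.rand(4, 1, 30, 30)  # patch-discriminator scores for fakes
outputs_gen = torch.rand(4, 3, 256, 256)     # generated frames
answers = torch.rand(4, 3, 256, 256)         # ground-truth frames

loss_gen_lsgan = 0.5 * torch.mean((output_disc_fake - 1) ** 2)
l1_loss = F.l1_loss(outputs_gen, answers)
loss_gen_total_lsgan = 5 * loss_gen_lsgan + 0.01 * l1_loss
print(loss_gen_total_lsgan.item())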
Example #18
def main():
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    #PSF_grid = np.load('./data/Schuler_PSF01.npz')['PSF']
    PSF_grid = np.load('./data/Schuler_PSF_facade.npz')['PSF']
    #PSF_grid = np.load('./data/Schuler_PSF03.npz')['PSF']
    #PSF_grid = np.load('./data/PSF.npz')['PSF']

    PSF_grid = PSF_grid.astype(np.float32)

    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.load_state_dict(torch.load('./data/usrnet.pth'), strict=True)
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    imgs = glob.glob('/home/xiu/databag/deblur/images/*/**.png',
                     recursive=True)
    imgs.sort()

    patch_size = 2 * 128
    num_patch = 2
    expand = PSF_grid.shape[2] // 2

    #positional alpha-beta parameters for HQS
    stage = 8
    ab_buffer = np.ones((gx, gy, 2 * stage + 1, 3), dtype=np.float32) * 0.1
    ab_buffer[:, :, 0, :] = 0.01
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    params = []
    params += [{"params": [ab], "lr": 1e-4}]
    for key, value in model.named_parameters():
        params += [{"params": [value], "lr": 1e-4}]

    optimizer = torch.optim.Adam(params, lr=1e-4)

    running = True

    while running:
        #alpha.beta
        img_idx = np.random.randint(len(imgs))
        img = imgs[img_idx]
        img_H = cv2.imread(img)
        w, h = img_H.shape[:2]

        mode = np.random.randint(5)
        px_start = np.random.randint(0, gx - num_patch + 1)
        py_start = np.random.randint(0, gy - num_patch + 1)
        if mode == 0:
            px_start = 0
        if mode == 1:
            px_start = gx - num_patch
        if mode == 2:
            py_start = 0
        if mode == 3:
            py_start = gy - num_patch

        x_start = np.random.randint(0, w - patch_size - expand * 2 + 1)
        y_start = np.random.randint(0, h - patch_size - expand * 2 + 1)
        PSF_patch = PSF_grid[px_start:px_start + num_patch,
                             py_start:py_start + num_patch]

        patch_H = img_H[x_start:x_start + patch_size + expand * 2,
                        y_start:y_start + patch_size + expand * 2]
        patch_L = util_deblur.blockConv2d(patch_H, PSF_patch, expand)

        block_size = patch_size // num_patch

        block_expand = max(patch_size // 16, expand)
        if block_expand > 0:
            patch_L_wrap = util_deblur.wrap_boundary_liu(
                patch_L,
                (patch_size + block_expand * 2, patch_size + block_expand * 2))
            #centralize
            patch_L_wrap = np.hstack(
                (patch_L_wrap[:, -block_expand:, :],
                 patch_L_wrap[:, :patch_size + block_expand, :]))
            patch_L_wrap = np.vstack(
                (patch_L_wrap[-block_expand:, :, :],
                 patch_L_wrap[:patch_size + block_expand, :, :]))
        else:
            patch_L_wrap = patch_L
        if block_expand > 0:
            x = util.uint2single(patch_L_wrap)
        else:
            x = util.uint2single(patch_L)
        x_blocky = []
        for h_ in range(num_patch):
            for w_ in range(num_patch):
                x_blocky.append(
                    x[w_ * block_size:w_ * block_size + block_size + block_expand * 2,
                      h_ * block_size:h_ * block_size + block_size + block_expand * 2])
        x_blocky = [util.single2tensor4(el) for el in x_blocky]
        x_blocky = torch.cat(x_blocky, dim=0)

        k_all = []
        for w_ in range(num_patch):
            for h_ in range(num_patch):
                k_all.append(util.single2tensor4(PSF_patch[h_, w_]))
        k = torch.cat(k_all, dim=0)
        x_gt = util.uint2single(patch_H[expand:-expand, expand:-expand])
        x_gt = util.single2tensor4(x_gt)

        [x_blocky, x_gt, k] = [el.to(device) for el in [x_blocky, x_gt, k]]

        #cd = F.softplus(ab[px_start:px_start+num_patch,py_start:py_start+num_patch].reshape(num_patch**2,2*stage+1,1,1))
        #for n_iter in range(optim_iter):
        cd = F.softplus(ab[px_start:px_start + num_patch,
                           py_start:py_start + num_patch])
        cd = cd.view(num_patch**2, 2 * stage + 1, 3)
        x_E = model.forward_patchdeconv(x_blocky,
                                        k,
                                        cd, [num_patch, num_patch],
                                        patch_sz=patch_size // num_patch)
        loss = 0
        #for xx in x_E[::2]:
        loss = F.l1_loss(x_E[-2], x_gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('loss {}'.format(loss.item()))

        x_E = x_E[:-1]

        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E[-1])
        patch_E_all = [util.tensor2uint(pp) for pp in x_E]
        patch_E_z = np.hstack((patch_E_all[::2]))
        patch_E_x = np.hstack((patch_E_all[1::2]))

        patch_E_show = np.vstack((patch_E_z, patch_E_x))
        if block_expand > 0:
            show = np.hstack((patch_H[expand:-expand, expand:-expand],
                              patch_L[block_expand:-block_expand,
                                      block_expand:-block_expand], patch_E))
        else:
            show = np.hstack((patch_H[expand:-expand,
                                      expand:-expand], patch_L, patch_E))

        #get kernel
        #cv2.imshow('stage',patch_E_show)
        #rgb = np.hstack((patch_E[:,:,0],patch_E[:,:,1],patch_E[:,:,2]))
        cv2.imshow('HL', show)
        #cv2.imshow('RGB',rgb)
        key = cv2.waitKey(1)

        if key == ord('q'):
            running = False
            break

        if key == ord('s'):
            ab_numpy = ab.detach().cpu().numpy().flatten()
            torch.save(model.state_dict(), 'usrnet_facade.pth')
            np.savetxt('ab_facade.txt', ab_numpy)

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), 'usrnet_facade.pth')
    np.savetxt('ab_facade.txt', ab_numpy)
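A minimal numpy sketch (hypothetical sizes) of the overlapping block split performed in the x_blocky loop above: each of the num_patch x num_patch blocks keeps block_expand pixels of context on every side.

import numpy as np

num_patch, block_size, block_expand = 2, 64, 8
side = num_patch * block_size + 2 * block_expand
x = np.zeros((side, side, 3))
blocks = [
    x[r * block_size:r * block_size + block_size + 2 * block_expand,
      c * block_size:c * block_size + block_size + 2 * block_expand]
    for r in range(num_patch) for c in range(num_patch)
]
assert all(b.shape == (block_size + 2 * block_expand,
                       block_size + 2 * block_expand, 3) for b in blocks)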
Example No. 19
    def forward(self, x, t):
        y = self.model(x)
        prefix = 'train' if self.training else 'val'
        loss = F.l1_loss(y, t)
        ppe.reporting.report({prefix + '/loss': loss})
        return {'y': y, 'loss': loss}
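A self-contained usage sketch of the forward() above, assuming the chainer-style Reporter API that pytorch-pfn-extras ports; the wrapped self.model is a stand-in nn.Linear, and ReportingRegressor is a hypothetical name for the class this method belongs to.

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_pfn_extras as ppe

class ReportingRegressor(nn.Module):
    # hypothetical wrapper hosting the forward() shown above
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, t):
        y = self.model(x)
        prefix = 'train' if self.training else 'val'
        loss = F.l1_loss(y, t)
        ppe.reporting.report({prefix + '/loss': loss})
        return {'y': y, 'loss': loss}

model = ReportingRegressor(nn.Linear(4, 1))
reporter = ppe.reporting.Reporter()
with reporter:  # reported values land in reporter.observation
    out = model(torch.randn(2, 4), torch.randn(2, 1))
print(reporter.observation)  # {'train/loss': tensor(...)}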
Example No. 20
    def stage2(self,
               source,
               target,
               source_pose,
               target_pose,
               other_pose,
               same_person_pose,
               index,
               k,
               get_image=False,
               fine_tune=False):
        imgs = torch.cat([source, source_pose], 1)

        self.optim_G.zero_grad()
        pose_latent = self.pose_encoder(target_pose)

        fake_sample = self.generator(imgs, pose_latent, k)[-1]

        real_score_256, real_score_128, real_layers = self.dis(
            target, index, target_pose, fine_tune)
        fake_score_256, fake_score_128, fake_layers = self.dis(
            fake_sample, index, target_pose, fine_tune)

        g_adv_256 = self.get_hinge_g(fake_score_256, real_score_256)
        g_adv_128 = self.get_hinge_g(fake_score_128, real_score_128)

        fm_loss = sum([
            F.l1_loss(fake_layer, real_layer.detach())
            for fake_layer, real_layer in zip(fake_layers, real_layers)
        ])
        perceptual_loss = sum([
            F.l1_loss(fake_layer, real_layer)
            for fake_layer, real_layer in zip(VGG(fake_sample), VGG(target))
        ])

        loss = g_adv_256 + g_adv_128 + fm_loss * 10 + perceptual_loss * 20
        loss.backward()
        self.optim_G.step()

        d1_adv_256, d1_adv_128 = self.optimize_D(target, fake_sample, index,
                                                 target_pose, fine_tune)
        with torch.no_grad():
            fake_sample2 = self.generator(imgs, pose_latent, k)[-1]
        d2_adv_256, d2_adv_128 = self.optimize_D(target, fake_sample2, index,
                                                 target_pose, fine_tune)

        res = [
            fm_loss.item(),
            perceptual_loss.item(),
            g_adv_256.item(),
            g_adv_128.item(),
            d1_adv_256,
            d1_adv_128,
            d2_adv_256,
            d2_adv_128,
        ]

        if get_image:
            res.append(fake_sample)
            with torch.no_grad():
                pose_latent = self.pose_encoder(other_pose)
                other_fake_sample = self.generator(imgs, pose_latent, k)[-1]
                res.append(other_fake_sample)
        return res
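get_hinge_g and optimize_D are defined elsewhere in the project; for reference, the standard hinge-GAN objectives look like the sketch below. This is an assumption: the extra real_score argument to get_hinge_g above suggests the project may actually use a relativistic variant.

import torch.nn.functional as F

def hinge_d_loss(real_score, fake_score):
    # discriminator: push real scores above +1 and fake scores below -1
    return F.relu(1. - real_score).mean() + F.relu(1. + fake_score).mean()

def hinge_g_loss(fake_score):
    # generator: raise the discriminator's score on generated samples
    return -fake_score.mean()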
Example No. 21
    # `resolution` is promoted to a parameter here; the original fragment
    # referenced it (and `self`) without defining either
    def l1_loss(self, source, target, resolution, reduction="mean"):
        return F.l1_loss(*self.cropper(source, target, resolution),
                         reduction=reduction)
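self.cropper is not shown in this fragment; a plausible stand-in (an assumption, named center_cropper here) center-crops both tensors to a square resolution so the L1 comparison is taken over aligned regions.

import torch
import torch.nn.functional as F

def center_cropper(source, target, resolution):
    def crop(t):
        h, w = t.shape[-2:]
        top, left = (h - resolution) // 2, (w - resolution) // 2
        return t[..., top:top + resolution, left:left + resolution]
    return crop(source), crop(target)

a, b = torch.randn(1, 3, 64, 64), torch.randn(1, 3, 48, 48)
loss = F.l1_loss(*center_cropper(a, b, 32), reduction="mean")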
Example No. 22
def ContentLoss(input, target):
    target = target.detach()
    loss = F.l1_loss(input, target)
    return loss
Example No. 23
def tts_train_loop(paths: Paths,
                   model: Tacotron,
                   optimizer,
                   train_set,
                   lr,
                   train_steps,
                   attn_example,
                   warmup_lr=False):
    device = next(
        model.parameters()).device  # use same device as model parameters

    for g in optimizer.param_groups:
        g['lr'] = lr

    total_iters = len(train_set)
    epochs = train_steps // total_iters + 1
    if warmup_lr:
        lrs = OneCycleLR(optimizer,
                         lr,
                         total_steps=epochs * total_iters,
                         pct_start=0.5,
                         div_factor=1000,
                         anneal_strategy='cos',
                         final_div_factor=1)

    for e in range(1, epochs + 1):

        start = time.time()
        running_loss = 0

        # Perform 1 epoch
        for i, (x, m, ids, _) in enumerate(train_set, 1):

            x, m = x.to(device), m.to(device)

            # Parallelize model onto GPUS using workaround due to python bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                m1_hat, m2_hat, attention = data_parallel_workaround(
                    model, x, m)
            else:
                m1_hat, m2_hat, attention = model(x, m)

            m1_loss = F.l1_loss(m1_hat, m)
            m2_loss = F.l1_loss(m2_hat, m)

            loss = m1_loss + m2_loss

            optimizer.zero_grad()
            loss.backward()
            if hp.tts_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.tts_clip_grad_norm).item()
                if np.isnan(grad_norm):
                    print('grad_norm was NaN!')

            optimizer.step()
            if warmup_lr:
                lrs.step()

            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.tts_checkpoint_every == 0:
                ckpt_name = f'taco_step{k}K'
                save_checkpoint('tts',
                                paths,
                                model,
                                optimizer,
                                name=ckpt_name,
                                is_silent=True)

            if attn_example in ids:
                idx = ids.index(attn_example)
                save_attention(np_now(attention[idx][:, :160]),
                               paths.tts_attention / f'{step}')
                save_spectrogram(np_now(m2_hat[idx]),
                                 paths.tts_mel_plot / f'{step}', 600)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:#.4} | {speed:#.2} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('tts', paths, model, optimizer, is_silent=True)
        model.log(paths.tts_log, msg)
        print(' ')
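The loop above only prints a warning when the clipped gradient norm comes back NaN; a stricter variant (a sketch with stand-in model and data, not the author's code) skips the optimizer step for that batch entirely:

import numpy as np
import torch
import torch.nn.functional as F

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = F.l1_loss(model(torch.randn(8, 2)), torch.randn(8, 1))
optimizer.zero_grad()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
if np.isfinite(grad_norm):  # drop the update entirely on NaN/Inf
    optimizer.step()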
Example No. 24
def StyleLoss(input, target):
    target = GramMatrix(target).detach()
    input = GramMatrix(input)
    loss = F.l1_loss(input, target)
    return loss
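GramMatrix is referenced above but not defined in this snippet; a common implementation (an assumption, following the usual style-transfer formulation) computes channel-wise feature correlations normalized by the feature-map size.

import torch

def GramMatrix(x):
    # x: (B, C, H, W) feature maps -> (B, C, C) correlation matrices
    b, c, h, w = x.shape
    feats = x.view(b, c, h * w)
    return feats @ feats.transpose(1, 2) / (c * h * w)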
def train(args, loader, generator, encoder, discriminator, discriminator2,
          vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device):
    # kwargs_d = {'detach_aux': args.detach_d_aux_head}
    if args.dataset == 'imagefolder':
        loader = sample_data2(loader)
    else:
        loader = sample_data(loader)

    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    else:
        inception = real_mean = real_cov = None
    mean_latent = None

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d2_module = None
    if discriminator2 is not None:
        if args.distributed:
            d2_module = discriminator2.module
        else:
            d2_module = discriminator2

    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256,
                                      device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    sample_x1 = sample_x[:, 0, ...]
    sample_x2 = sample_x[:, -1, ...]
    sample_idx = torch.randperm(args.n_sample)

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        frames = [get_batch(loader, device) for _ in range(n_step_max)]

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(encoder, False)
        requires_grad(discriminator, True)
        for step_index in range(args.n_step_d):
            frames1, frames2 = frames[step_index]
            real_img = frames1
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            if args.use_ema:
                g_ema.eval()
                fake_img, _ = g_ema(noise)
            else:
                fake_img, _ = generator(noise)
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img_aug)
            d_loss_fake = F.softplus(fake_pred).mean()
            d_loss_real = F.softplus(-real_pred).mean()
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()

            d_loss_rec = 0.
            if args.lambda_rec_d > 0 and not args.decouple_d:  # Do not train D on x_rec if decouple_d
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)
                rec_pred = discriminator(rec_img)
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["rec_score"] = rec_pred.mean()

            d_loss_cross = 0.
            if args.lambda_cross_d > 0 and not args.decouple_d:
                if args.use_ema:
                    e_ema.eval()
                    w1, _ = e_ema(frames1)
                    w2, _ = e_ema(frames2)
                else:
                    w1, _ = encoder(frames1)
                    w2, _ = encoder(frames2)
                dw = w2 - w1
                dw_shuffle = dw[torch.randperm(args.batch), ...]
                if args.use_ema:
                    g_ema.eval()
                    cross_img, _ = g_ema([w1 + dw_shuffle],
                                         input_is_latent=True)
                else:
                    cross_img, _ = generator([w1 + dw_shuffle],
                                             input_is_latent=True)
                cross_pred = discriminator(cross_img)
                d_loss_cross = F.softplus(cross_pred).mean()

            d_loss_fake_cross = 0.
            if args.lambda_fake_cross_d > 0:
                if args.use_ema:
                    e_ema.eval()
                    w1, _ = e_ema(frames1)
                    w2, _ = e_ema(frames2)
                else:
                    w1, _ = encoder(frames1)
                    w2, _ = encoder(frames2)
                dw = w2 - w1
                noise = mixing_noise(args.batch, args.latent, args.mixing,
                                     device)
                if args.use_ema:
                    g_ema.eval()
                    style = g_ema.get_styles(noise).view(args.batch, -1)
                else:
                    style = generator.get_styles(noise).view(args.batch, -1)
                if dw.shape[1] < style.shape[1]:  # W space
                    dw = dw.repeat(1, args.n_latent)
                if args.use_ema:
                    cross_img, _ = g_ema([style + dw], input_is_latent=True)
                else:
                    cross_img, _ = generator([style + dw],
                                             input_is_latent=True)
                fake_cross_pred = discriminator(cross_img)
                d_loss_fake_cross = F.softplus(fake_cross_pred).mean()

            d_loss = (d_loss_real + d_loss_fake +
                      d_loss_fake_cross * args.lambda_fake_cross_d +
                      d_loss_rec * args.lambda_rec_d +
                      d_loss_cross * args.lambda_cross_d)
            loss_dict["d"] = d_loss

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Discriminator2
        if args.decouple_d and discriminator2 is not None:
            requires_grad(generator, False)
            requires_grad(encoder, False)
            requires_grad(discriminator2, True)
            for step_index in range(
                    args.n_step_e):  # n_step_d2 is same as n_step_e
                frames1, frames2 = frames[step_index]
                real_img = frames1
                if args.use_ema:
                    e_ema.eval()
                    g_ema.eval()
                    latent_real, _ = e_ema(real_img)
                    rec_img, _ = g_ema([latent_real], input_is_latent=True)
                else:
                    latent_real, _ = encoder(real_img)
                    rec_img, _ = generator([latent_real], input_is_latent=True)
                rec_pred = discriminator2(rec_img)
                d2_loss_rec = F.softplus(rec_pred).mean()
                real_pred1 = discriminator2(frames1)
                d2_loss_real = F.softplus(-real_pred1).mean()
                if args.use_frames2_d:
                    real_pred2 = discriminator2(frames2)
                    d2_loss_real += F.softplus(-real_pred2).mean()

                if args.use_ema:
                    e_ema.eval()
                    w1, _ = e_ema(frames1)
                    w2, _ = e_ema(frames2)
                else:
                    w1, _ = encoder(frames1)
                    w2, _ = encoder(frames2)
                dw = w2 - w1
                dw_shuffle = dw[torch.randperm(args.batch), ...]
                cross_img, _ = generator([w1 + dw_shuffle],
                                         input_is_latent=True)
                cross_pred = discriminator2(cross_img)
                d2_loss_cross = F.softplus(cross_pred).mean()

                d2_loss = d2_loss_real + d2_loss_rec + d2_loss_cross
                loss_dict["d2"] = d2_loss
                loss_dict["rec_score"] = rec_pred.mean()
                loss_dict["cross_score"] = cross_pred.mean()

                discriminator2.zero_grad()
                d2_loss.backward()
                d2_optim.step()

            d_regularize = i % args.d_reg_every == 0
            if d_regularize:
                real_img.requires_grad = True
                real_pred = discriminator2(real_img)
                r1_loss = d_r1_loss(real_pred, real_img)
                discriminator2.zero_grad()
                (args.r1 / 2 * r1_loss * args.d_reg_every +
                 0 * real_pred[0]).backward()
                d2_optim.step()

        # Train Encoder
        requires_grad(encoder, True)
        requires_grad(generator, args.train_ge)
        requires_grad(discriminator, False)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            frames1, frames2 = frames[step_index]
            real_img = frames1
            latent_real, _ = encoder(real_img)
            if args.use_ema:
                g_ema.eval()
                rec_img, _ = g_ema([latent_real], input_is_latent=True)
            else:
                rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_adv > 0:
                if not args.decouple_d:
                    rec_pred = discriminator(rec_img)
                else:
                    rec_pred = discriminator2(rec_img)
                adv_loss = g_nonsaturating_loss(rec_pred)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img)**2)
                else:
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)

            e_loss = pix_loss * args.lambda_pix + vgg_loss * args.lambda_vgg + adv_loss * args.lambda_adv
            loss_dict["e"] = e_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            if args.train_ge:
                encoder.zero_grad()
                generator.zero_grad()
                e_loss.backward()
                e_optim.step()
                g_optim.step()
            else:
                encoder.zero_grad()
                e_loss.backward()
                e_optim.step()

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if discriminator2 is not None:
            requires_grad(discriminator2, False)
        frames1, frames2 = frames[0]
        real_img = frames1
        g_loss_fake = 0.
        if args.lambda_fake_g > 0:
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            fake_img, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)
            fake_pred = discriminator(fake_img)
            g_loss_fake = g_nonsaturating_loss(fake_pred)

        g_loss_rec = 0.
        if args.lambda_rec_g > 0:
            if args.use_ema:
                e_ema.eval()
                latent_real, _ = e_ema(real_img)
            else:
                latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if not args.decouple_d:
                rec_pred = discriminator(rec_img)
            else:
                rec_pred = discriminator2(rec_img)
            g_loss_rec = g_nonsaturating_loss(rec_pred)

        g_loss_cross = 0.
        if args.lambda_cross_g > 0:
            if args.use_ema:
                e_ema.eval()
                w1, _ = e_ema(frames1)
                w2, _ = e_ema(frames2)
            else:
                w1, _ = encoder(frames1)
                w2, _ = encoder(frames2)
            dw = w2 - w1
            dw_shuffle = dw[torch.randperm(args.batch), ...]
            cross_img, _ = generator([w1 + dw_shuffle], input_is_latent=True)
            if not args.decouple_d:
                cross_pred = discriminator(cross_img)
            else:
                cross_pred = discriminator2(cross_img)
            g_loss_cross = g_nonsaturating_loss(cross_pred)

        g_loss_fake_cross = 0.
        if args.lambda_fake_cross_g > 0:
            if args.use_ema:
                e_ema.eval()
                w1, _ = e_ema(frames1)
                w2, _ = e_ema(frames2)
            else:
                w1, _ = encoder(frames1)
                w2, _ = encoder(frames2)
            dw = w2 - w1
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            style = generator.get_styles(noise).view(args.batch, -1)
            if dw.shape[1] < style.shape[1]:  # W space
                dw = dw.repeat(1, args.n_latent)
            cross_img, _ = generator([style + dw], input_is_latent=True)
            fake_cross_pred = discriminator(cross_img)
            g_loss_fake_cross = g_nonsaturating_loss(fake_cross_pred)

        g_loss = (g_loss_fake * args.lambda_fake_g +
                  g_loss_rec * args.lambda_rec_g +
                  g_loss_cross * args.lambda_cross_g +
                  g_loss_fake_cross * args.lambda_fake_cross_g)
        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(e_ema, e_module, accum)
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    latent_x, _ = e_ema(sample_x)
                    fake_x, _ = generator([latent_x],
                                          input_is_latent=True,
                                          return_latents=False)
                    sample_pix_loss = torch.sum((sample_x - fake_x)**2)
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write(
                        f"{i:07d}; pix: {avg_pix_loss.avg}; vgg: {avg_vgg_loss.avg}; "
                        f"ref: {sample_pix_loss.item()};\n")

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    features = extract_feature_from_samples(
                        g_ema, inception, args.truncation, mean_latent, 64,
                        args.n_sample_fid, args.device).numpy()
                    sample_mean = np.mean(features, 0)
                    sample_cov = np.cov(features, rowvar=False)
                    fid = calc_fid(sample_mean, sample_cov, real_mean,
                                   real_cov)
                print("fid:", fid)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}; fid: {float(fid):.4f};\n")

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % args.log_every == 0:
                with torch.no_grad():
                    # Fixed fake samples
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    # Reconstruction samples
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    latent_real, _ = e_ema(sample_x)
                    fake_img, _ = g_ema([latent_real],
                                        input_is_latent=True,
                                        return_latents=False)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         fake_img.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
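accumulate() is defined elsewhere; the usual stylegan2-pytorch EMA update it performs (an assumption here) is the sketch below. With accum = 0.5 ** (32 / 10000) as above, the decay is roughly 0.9978 per step.

def accumulate(model_ema, model, decay=0.999):
    # in-place exponential moving average of model's parameters into model_ema
    ema_params = dict(model_ema.named_parameters())
    for name, p in model.named_parameters():
        ema_params[name].data.mul_(decay).add_(p.data, alpha=1 - decay)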
def train():
    from config import IM_SIZE_AE, BATCH_SIZE_AE, NFC, NBR_CLS, DATALOADER_WORKERS, ITERATION_AE
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, SAVE_FOLDER, TRIAL_NAME, LOG_INTERVAL
    from config import DATA_NAME
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_AE,
                                 rand_crop=True)
    print(len(dataset))
    dataloader = iter(DataLoader(dataset, BATCH_SIZE_AE, \
        sampler=InfiniteSamplerWrapper(dataset), num_workers=DATALOADER_WORKERS, pin_memory=True))

    dataset_ss = SelfSupervisedDataset(data_root_colorful,
                                       data_root_sketch_3,
                                       im_size=IM_SIZE_AE,
                                       nbr_cls=NBR_CLS,
                                       rand_crop=True)
    print(len(dataset_ss), len(dataset_ss.frame))
    dataloader_ss = iter(DataLoader(dataset_ss, BATCH_SIZE_AE, \
        sampler=InfiniteSamplerWrapper(dataset_ss), num_workers=DATALOADER_WORKERS, pin_memory=True))

    style_encoder = StyleEncoder(nfc=NFC, nbr_cls=NBR_CLS).cuda()
    content_encoder = ContentEncoder(nfc=NFC).cuda()
    decoder = Decoder(nfc=NFC).cuda()

    opt_c = optim.Adam(content_encoder.parameters(),
                       lr=2e-4,
                       betas=(0.5, 0.999))
    opt_s = optim.Adam(style_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(decoder.parameters(), lr=2e-4, betas=(0.5, 0.999))

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()

    from config import PRETRAINED_AE_PATH, PRETRAINED_AE_ITER
    if PRETRAINED_AE_PATH is not None:
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER
        ckpt = torch.load(PRETRAINED_AE_PATH)

        print(PRETRAINED_AE_PATH)

        style_encoder.load_state_dict(ckpt['s'])
        content_encoder.load_state_dict(ckpt['c'])
        decoder.load_state_dict(ckpt['d'])

        opt_c.load_state_dict(ckpt['opt_c'])
        opt_s.load_state_dict(ckpt['opt_s'])
        opt_d.load_state_dict(ckpt['opt_d'])
        print('loaded pre-trained AE')

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()
    opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(),
                           lr=2e-4,
                           betas=(0.5, 0.999))

    saved_image_folder, saved_model_folder = make_folders(
        SAVE_FOLDER, 'AE_' + TRIAL_NAME)
    log_file_path = saved_image_folder + '/../ae_log.txt'
    log_file = open(log_file_path, 'w')
    log_file.close()
    ## for logging
    losses_sf_consist = AverageMeter()
    losses_cf_consist = AverageMeter()
    losses_cls = AverageMeter()
    losses_rec_rd = AverageMeter()
    losses_rec_org = AverageMeter()
    losses_rec_grey = AverageMeter()

    import lpips
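    # note: PerceptualLoss is the older PerceptualSimilarity API; recent
    # lpips releases expose lpips.LPIPS(net='vgg') instead (assumption: this
    # repo pins the old package, so the call below is kept as-is)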
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    for iteration in tqdm(range(ITERATION_AE)):

        if iteration % (
            (NBR_CLS * 100) // BATCH_SIZE_AE) == 0 and iteration > 1:
            dataset_ss._next_set()
            dataloader_ss = iter(
                DataLoader(dataset_ss,
                           BATCH_SIZE_AE,
                           sampler=InfiniteSamplerWrapper(dataset_ss),
                           num_workers=DATALOADER_WORKERS,
                           pin_memory=True))
            style_encoder.reset_cls()
            opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(),
                                   lr=2e-4,
                                   betas=(0.5, 0.999))

            opt_s.param_groups[0]['lr'] = 1e-4
            opt_d.param_groups[0]['lr'] = 1e-4

        ### 1. train the encoder with self-supervision methods
        rgb_img_rd, rgb_img_org, skt_org, skt_bold, skt_erased, skt_erased_bold, img_idx = next(
            dataloader_ss)
        rgb_img_rd = rgb_img_rd.cuda()
        rgb_img_org = rgb_img_org.cuda()
        img_idx = img_idx.cuda()

        skt_org = F.interpolate(skt_org, size=512).cuda()
        skt_bold = F.interpolate(skt_bold, size=512).cuda()
        skt_erased = F.interpolate(skt_erased, size=512).cuda()
        skt_erased_bold = F.interpolate(skt_erased_bold, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector_rd, pred_cls_rd = style_encoder(rgb_img_rd)
        style_vector_org, pred_cls_org = style_encoder(rgb_img_org)

        content_feats = content_encoder(skt_org)
        content_feats_bold = content_encoder(skt_bold)
        content_feats_erased = content_encoder(skt_erased)
        content_feats_eb = content_encoder(skt_erased_bold)

        rd = random.randint(0, 3)
        gimg_rd = None
        if rd == 0:
            gimg_rd = decoder(content_feats, style_vector_rd)
        elif rd == 1:
            gimg_rd = decoder(content_feats_bold, style_vector_rd)
        elif rd == 2:
            gimg_rd = decoder(content_feats_erased, style_vector_rd)
        elif rd == 3:
            gimg_rd = decoder(content_feats_eb, style_vector_rd)


        loss_cf_consist = (
            loss_for_list_perm(F.mse_loss, content_feats_bold, content_feats) +
            loss_for_list_perm(F.mse_loss, content_feats_erased, content_feats) +
            loss_for_list_perm(F.mse_loss, content_feats_eb, content_feats))

        loss_sf_consist = 0
        for loss_idx in range(3):
            # pull the augmented style towards the original and push it away
            # from a shuffled (mismatched) pairing
            loss_sf_consist += (
                -F.cosine_similarity(style_vector_rd[loss_idx],
                                     style_vector_org[loss_idx].detach()).mean() +
                F.cosine_similarity(
                    style_vector_rd[loss_idx],
                    style_vector_org[loss_idx][torch.randperm(
                        BATCH_SIZE_AE)].detach()).mean())

        loss_cls = F.cross_entropy(pred_cls_rd, img_idx) + F.cross_entropy(
            pred_cls_org, img_idx)
        loss_rec_rd = F.mse_loss(gimg_rd, rgb_img_org)
        if DATA_NAME != 'shoe':
            loss_rec_rd += percept(
                F.adaptive_avg_pool2d(gimg_rd, output_size=256),
                F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
        else:
            loss_rec_rd += F.l1_loss(gimg_rd, rgb_img_org)

        loss_total = loss_cls + loss_sf_consist + loss_rec_rd + loss_cf_consist
        loss_total.backward()

        opt_s.step()
        opt_s_cls.step()
        opt_c.step()
        opt_d.step()

        ### 2. train as AutoEncoder
        rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

        rgb_img = rgb_img.cuda()

        rd = random.randint(0, 3)
        if rd == 0:
            skt_img = skt_img_1
        elif rd == 1:
            skt_img = skt_img_2
        else:
            skt_img = skt_img_3

        skt_img = F.interpolate(skt_img, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector, _ = style_encoder(rgb_img)
        content_feats = content_encoder(skt_img)
        gimg = decoder(content_feats, style_vector)

        loss_rec_org = F.mse_loss(gimg, rgb_img)
        if DATA_NAME != 'shoe':
            loss_rec_org += percept(
                F.adaptive_avg_pool2d(gimg, output_size=256),
                F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
        #else:
        #    loss_rec_org += F.l1_loss(gimg, rgb_img)

        loss_rec = loss_rec_org
        if DATA_NAME == 'shoe':
            ### the grey image reconstruction
            perm = true_randperm(BATCH_SIZE_AE)
            gimg_perm = decoder(content_feats, [s[perm] for s in style_vector])
            gimg_grey = gimg_perm.mean(dim=1, keepdim=True)
            real_grey = rgb_img.mean(dim=1, keepdim=True)
            loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
            loss_rec += loss_rec_grey
        loss_rec.backward()

        opt_s.step()
        opt_d.step()
        opt_c.step()

        ### Logging
        losses_cf_consist.update(loss_cf_consist.mean().item(), BATCH_SIZE_AE)
        losses_sf_consist.update(loss_sf_consist.mean().item(), BATCH_SIZE_AE)
        losses_cls.update(loss_cls.mean().item(), BATCH_SIZE_AE)
        losses_rec_rd.update(loss_rec_rd.item(), BATCH_SIZE_AE)
        losses_rec_org.update(loss_rec_org.item(), BATCH_SIZE_AE)
        if DATA_NAME == 'shoe':
            losses_rec_grey.update(loss_rec_grey.item(), BATCH_SIZE_AE)

        if iteration % LOG_INTERVAL == 0:
            log_msg = ('Train Stage 1: AE: \nrec_rd: %.4f  rec_org: %.4f  '
                       'cls: %.4f  style_consist: %.4f  content_consist: %.4f  '
                       'rec_grey: %.4f' %
                       (losses_rec_rd.avg, losses_rec_org.avg, losses_cls.avg,
                        losses_sf_consist.avg, losses_cf_consist.avg,
                        losses_rec_grey.avg))

            print(log_msg)

            if log_file_path is not None:
                log_file = open(log_file_path, 'a')
                log_file.write(log_msg + '\n')
                log_file.close()

            losses_sf_consist.reset()
            losses_cls.reset()
            losses_rec_rd.reset()
            losses_rec_org.reset()
            losses_cf_consist.reset()
            losses_rec_grey.reset()

        if iteration % SAVE_IMAGE_INTERVAL == 0:
            vutils.save_image(torch.cat([
                rgb_img_rd,
                F.interpolate(skt_org.repeat(1, 3, 1, 1), size=512), gimg_rd
            ]),
                              '%s/rd_%d.jpg' % (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))
            if DATA_NAME != 'shoe':
                with torch.no_grad():
                    perm = true_randperm(BATCH_SIZE_AE)
                    gimg_perm = decoder([c for c in content_feats],
                                        [s[perm] for s in style_vector])
            vutils.save_image(torch.cat([
                rgb_img,
                F.interpolate(skt_img.repeat(1, 3, 1, 1), size=512), gimg,
                gimg_perm
            ]),
                              '%s/org_%d.jpg' %
                              (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))

        if iteration % SAVE_MODEL_INTERVAL == 0:
            print('Saving history model')
            torch.save(
                {
                    's': style_encoder.state_dict(),
                    'd': decoder.state_dict(),
                    'c': content_encoder.state_dict(),
                    'opt_c': opt_c.state_dict(),
                    'opt_s_cls': opt_s_cls.state_dict(),
                    'opt_s': opt_s.state_dict(),
                    'opt_d': opt_d.state_dict(),
                }, '%s/%d.pth' % (saved_model_folder, iteration))

    torch.save(
        {
            's': style_encoder.state_dict(),
            'd': decoder.state_dict(),
            'c': content_encoder.state_dict(),
            'opt_c': opt_c.state_dict(),
            'opt_s_cls': opt_s_cls.state_dict(),
            'opt_s': opt_s.state_dict(),
            'opt_d': opt_d.state_dict(),
        }, '%s/%d.pth' % (saved_model_folder, ITERATION_AE))
Example No. 27
def main():
    G = Generator(zdim = args.zd).cuda()
    D = Discriminator().cuda()
    print("[INFO] Loaded G,D")
    
    params_G = list(filter(lambda p:p.requires_grad, G.parameters()))
    optimizer_G = optim.Adam(params_G, lr=0.0002, betas=(0.5,0.999))
    params_D = list(filter(lambda p: p.requires_grad, D.parameters()))
    optimizer_D = optim.Adam(params_D, lr=0.0002, betas=(0.5,0.999))

    if args.c is not None:
        G.load_state_dict(torch.load("./ckpts/{}_G.pth".format(args.c)).state_dict(), strict=True)
        D.load_state_dict(torch.load("./ckpts/{}_D.pth".format(args.c)).state_dict(), strict=True)
        print("Model restored")
    

    data_lists = os.listdir(args.dset)
    print(len(data_lists),args.dset)
    # shuffle(data_lists)
    print("[INFO] Got data and starting training")
    for epoch in tqdm(range(args.e)):
        for counter in range(int(len(data_lists)/args.nb)):
            # Data: each file stores 32 vertically stacked 128x128 frames;
            # build a batch of args.nb videos, resizing frames to 64x64
            real_video = []
            for i in range(len(data_lists)):
                b = data_lists[counter * args.nb + i]
                img = np.asarray(
                    PIL.Image.open(f"{args.dset}/{b.strip()}")) / 127.5 - 1
                if img.shape[0] < 128 * 32:
                    continue
                frames = [
                    resize(img[f * 128:(f + 1) * 128], (64, 64),
                           anti_aliasing=True) for f in range(32)
                ]
                real_video.append(np.stack(frames, 0))
                if len(real_video) >= args.nb:
                    break

            # (nb, T, H, W, C) -> (nb, C, T, H, W)
            real_video = torch.from_numpy(
                np.stack(real_video,
                         0).astype(np.float32).transpose((0, 4, 1, 2, 3))).cuda()

            # D
            noise = torch.from_numpy(np.random.normal(0, 1, size=[args.nb,args.zd]).astype(np.float32)).cuda()
            with torch.no_grad():
                fake_video, f, b, m = G(noise)

            logit_real = D(real_video)
            logit_fake = D(fake_video.detach())

            prob_real = torch.mean(torch.sigmoid(logit_real))
            prob_fake = torch.mean(torch.sigmoid(logit_fake))

            loss_real = F.binary_cross_entropy_with_logits(logit_real, torch.ones_like(logit_real))
            loss_fake = F.binary_cross_entropy_with_logits(logit_fake, torch.zeros_like(logit_fake))
            loss_D = torch.mean(torch.stack([loss_real, loss_fake]))

            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

            # G
            noise = torch.from_numpy(np.random.normal(0, 1, size=[args.nb,args.zd]).astype(np.float32)).cuda()
            gen_video, f, b, m = G(noise)

            logit_gen = D(gen_video)

            loss_gen = F.binary_cross_entropy_with_logits(logit_gen, torch.ones_like(logit_gen))
            # L1 penalty pulls the mask m towards zero (sparsity)
            loss_G = loss_gen + args.l * F.l1_loss(m, torch.zeros_like(m))

            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # Print status
            print("Epoch {:d}/{:d} | Iter {:d}/{:d} | D {:.4e} | G {:.4e} | Real {:.2f} | Fake {:.2f}".format(epoch, args.e, counter, int(len(data_lists)/args.nb), loss_D, loss_G, prob_real, prob_fake))

            process_and_write_video(gen_video[0:1].cpu().data.numpy(), "curr_video")
            process_and_write_image(b.cpu().data.numpy(), "curr_bg")

        if (epoch+1) % args.s == 0:
            process_and_write_video(gen_video[0:1].cpu().data.numpy(), "epoch{}_iter{}_video".format(epoch, counter))
            process_and_write_image(b.cpu().data.numpy(), "epoch{}_iter{}_bg".format(epoch, counter))
                
            curr_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
            # save into ./ckpts to match the restore path used above
            torch.save(G, "./ckpts/{}_{}_{}_G.pth".format(curr_time, epoch, counter))
            torch.save(D, "./ckpts/{}_{}_{}_D.pth".format(curr_time, epoch, counter))
            print("Checkpoints saved")
Example No. 28
    def train(self):
        loss = {}
        nrow = min(int(np.sqrt(self.batch_size)), 8)
        n_samples = nrow * nrow
        iter_per_epoch = len(self.train_loader.dataset) // self.batch_size
        max_iteration = self.num_epoch * iter_per_epoch
        lambda_l1 = 0.2
        print('Start training...')
        for epoch in tqdm(range(self.resume_epoch, self.num_epoch)):
            for i, (x_real, noise,
                    label) in enumerate(tqdm(self.train_loader)):

                # lr decay
                if epoch * iter_per_epoch + i >= self.lr_decay_start:
                    utils.decay_lr(self.g_optimizer, max_iteration,
                                   self.lr_decay_start, self.g_lr)
                    utils.decay_lr(self.d_optimizer, max_iteration,
                                   self.lr_decay_start, self.d_lr)
                    if i % 1000 == 0:
                        print('d_lr / g_lr is updated to {:.8f} / {:.8f} !'.
                              format(self.d_optimizer.param_groups[0]['lr'],
                                     self.g_optimizer.param_groups[0]['lr']))

                x_real = x_real.to(self.device)
                noise = noise.to(self.device)
                label = label.to(self.device)
                # =================================================================================== #
                #							  1. Train the discriminator							  #
                # =================================================================================== #
                for param in self.D.parameters():
                    param.requires_grad = True

                dis_real, real_list = self.D(x_real, label)
                real_list = [h.detach() for h in real_list]

                x_fake = self.G(noise, label).detach()
                dis_fake, _ = self.D(x_fake, label)

                d_loss_real, d_loss_fake = self.dis_hinge(dis_real, dis_fake)

                # sample
                # real_iter is created lazily and recreated when exhausted
                try:
                    x_real2, label2 = next(real_iter)
                except (NameError, StopIteration):
                    real_iter = iter(self.real_loader)
                    x_real2, label2 = next(real_iter)
                x_real2 = x_real2.to(self.device)
                label2 = label2.to(self.device)

                noise2 = torch.FloatTensor(
                    utils.truncated_normal(self.batch_size * self.z_dim)).view(
                        self.batch_size, self.z_dim).to(self.device)
                # noise2 = torch.randn(self.batch_size, self.z_dim).to(self.device)
                dis_real2, _ = self.D(x_real2, label2)
                x_fake2 = self.G(noise2, label2).detach()
                dis_fake2, _ = self.D(x_fake2, label2)
                d_loss_real2, d_loss_fake2 = self.dis_hinge(
                    dis_real2, dis_fake2)

                # Backward and optimize.
                d_loss = d_loss_real + d_loss_fake + 0.2 * (d_loss_real2 +
                                                            d_loss_fake2)

                self.d_optimizer.zero_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging.
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_real2'] = d_loss_real2.item()
                loss['D/loss_fake2'] = d_loss_fake2.item()

                # =================================================================================== #
                #								2. Train the generator								  #
                # =================================================================================== #

                x_fake = self.G(noise, label)

                for param in self.D.parameters():
                    param.requires_grad = False

                dis_fake, fake_list = self.D(x_fake, label)

                g_loss_feat = self.KDLoss(real_list, fake_list)
                g_loss_pix = F.l1_loss(x_fake, x_real)
                g_loss = g_loss_feat + lambda_l1 * g_loss_pix
                loss['G/loss_ft'] = g_loss_feat.item()
                loss['G/loss_l1'] = g_loss_pix.item()

                if (i + 1) % self.n_critic == 0:
                    dis_fake, _ = self.D(x_fake, label)
                    g_loss_fake = self.gen_hinge(dis_fake)

                    g_loss += self.lambda_gan * g_loss_fake

                    # sample
                    noise2 = torch.FloatTensor(
                        utils.truncated_normal(
                            self.batch_size * self.z_dim)).view(
                                self.batch_size, self.z_dim).to(self.device)
                    # noise2 = torch.randn(self.batch_size, self.z_dim).to(self.device)
                    x_fake2 = self.G(noise2, label2)
                    dis_fake2, _ = self.D(x_fake2, label2)
                    g_loss_fake2 = self.gen_hinge(dis_fake2)
                    g_loss += 0.2 * self.lambda_gan * g_loss_fake2

                    loss['G/loss_fake'] = g_loss_fake.item()
                    loss['G/loss_fake2'] = g_loss_fake2.item()

                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # =================================================================================== #
                #								  3. Miscellaneous									  #
                # =================================================================================== #

                # Print out training information.
                if (i + 1) % self.log_step == 0:
                    log = "[{}/{}]".format(epoch, i)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, i + 1)

            if epoch == 0 or (epoch + 1) % self.sample_step == 0:
                with torch.no_grad():
                    """
					# randomly sampled noise
					noise = torch.FloatTensor(utils.truncated_normal(n_samples*self.z_dim)) \
										.view(n_samples, self.z_dim).to(self.device)
					label = label[:nrow].repeat(nrow)

					#label = np.random.choice(1000, nrow, replace=False)
					#label = torch.tensor(label).repeat(10).to(self.device)
					x_sample = self.G(noise, label)
					sample_path = os.path.join(self.sample_dir, '{}-sample.png'.format(epoch+1))
					save_image(utils.denorm(x_sample.cpu()), sample_path, nrow=nrow, padding=0)
					"""
                    # recons
                    n = min(x_real.size(0), 8)
                    comparison = torch.cat([x_real[:n], x_fake[:n]])
                    sample_path = os.path.join(
                        self.sample_dir, '{}-train.png'.format(epoch + 1))
                    save_image(utils.denorm(comparison.cpu()), sample_path)
                    print('Save fake images into {}...'.format(sample_path))

                    # noise2
                    comparison = torch.cat([x_real2[:n], x_fake2[:n]])
                    sample_path = os.path.join(
                        self.sample_dir, '{}-random.png'.format(epoch + 1))
                    save_image(utils.denorm(comparison.cpu()), sample_path)
                    print('Save fake images into {}...'.format(sample_path))

                    # noise sampled from BigGAN's test set
                    # test_iter is created lazily and recreated when exhausted
                    try:
                        x_real, noise, label = next(test_iter)
                    except (NameError, StopIteration):
                        test_iter = iter(self.test_loader)
                        x_real, noise, label = next(test_iter)
                    noise = noise.to(self.device)
                    label = label.to(self.device)

                    x_fake = self.G(noise, label).detach().cpu()
                    n = min(x_real.size(0), 8)
                    comparison = torch.cat([x_real[:n], x_fake[:n]])
                    sample_path = os.path.join(self.sample_dir,
                                               '{}-test.png'.format(epoch + 1))
                    save_image(utils.denorm(comparison.cpu()), sample_path)
                    print('Save fake images into {}...'.format(sample_path))

            lambda_l1 = max(0.00, lambda_l1 - 0.01)
            # Save model checkpoints.
            if (epoch + 1) % self.model_save_step == 0:
                utils.save_model(self.model_save_dir, epoch + 1, self.G,
                                 self.D, self.g_optimizer, self.d_optimizer)
Example No. 29
def l1(output, target):
    return F.l1_loss(output, target)
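Sanity check for the thin wrapper above: with the default reduction='mean', F.l1_loss is exactly the mean absolute error computed by hand.

import torch
import torch.nn.functional as F

out, tgt = torch.randn(4, 3), torch.randn(4, 3)
assert torch.allclose(F.l1_loss(out, tgt), (out - tgt).abs().mean())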
Example No. 30
    def cal_loss(self, xhr, cam_ext):
        
        ### reconstruction loss
        loss_rec = self.weight_loss_rec*F.l1_loss(xhr, self.xhr_rec)
        xh_rec = GeometryTransformer.convert_to_3D_rot(self.xhr_rec)

        ### vposer loss
        vposer_pose = xh_rec[:,16:48]
        loss_vposer = self.weight_loss_vposer * torch.mean(vposer_pose**2)


        ### contact loss
        body_param_rec = BodyParamParser.body_params_encapsulate_batch(xh_rec)
        joint_rot_batch = self.vposer.decode(body_param_rec['body_pose_vp'], 
                                           output_type='aa').view(self.batch_size, -1)
 
        body_param_ = {}
        for key in body_param_rec.keys():
            if key in ['body_pose_vp']:
                continue
            else:
                body_param_[key] = body_param_rec[key]

        smplx_output = self.body_mesh_model(return_verts=True, 
                                              body_pose=joint_rot_batch,
                                              **body_param_)
        body_verts_batch = smplx_output.vertices #[b, 10475,3]
        body_verts_batch = GeometryTransformer.verts_transform(body_verts_batch, cam_ext)

        vid, fid = GeometryTransformer.get_contact_id(
                                body_segments_folder=self.contact_id_folder,
                                contact_body_parts=self.contact_part)
        body_verts_contact_batch = body_verts_batch[:, vid, :]

        dist_chamfer_contact = ext.chamferDist()
        contact_dist, _ = dist_chamfer_contact(body_verts_contact_batch.contiguous(), 
                                                self.s_verts_batch.contiguous())

        loss_contact = self.weight_contact * torch.mean(torch.sqrt(contact_dist+1e-4)/(torch.sqrt(contact_dist+1e-4)+1.0))  


        ### sdf collision loss
        s_grid_min_batch = self.s_grid_min_batch.unsqueeze(1)
        s_grid_max_batch = self.s_grid_max_batch.unsqueeze(1)

        norm_verts_batch = (body_verts_batch - s_grid_min_batch) / (s_grid_max_batch - s_grid_min_batch) *2 -1
        n_verts = norm_verts_batch.shape[1]
        body_sdf_batch = F.grid_sample(self.s_sdf_batch.unsqueeze(1), 
                                        norm_verts_batch[:,:,[2,1,0]].view(-1, n_verts,1,1,3),
                                        padding_mode='border')


        # if there are no penetrating vertices then set sdf_penetration_loss = 0
        if body_sdf_batch.lt(0).sum().item() < 1:
            loss_sdf_pene = torch.tensor(0.0, dtype=torch.float32, device=self.device)
        else:
            loss_sdf_pene = body_sdf_batch[body_sdf_batch < 0].abs().mean()

        loss_collision = self.weight_collision * loss_sdf_pene


        return loss_rec, loss_vposer, loss_contact, loss_collision
Example No. 31
            # feed the data forward through discriminator_b
            output_disc_real_b = disc_model_b(torch.cat((condition_vectors, answers), 1))
            output_disc_fake_b = disc_model_b(torch.cat((condition_vectors, outputs_gen_a_to_b), 1))

            # loss functions

            # lsgan loss for the discriminator_a
            loss_disc_a_lsgan = 0.5 * (torch.mean((output_disc_real_a - 1) ** 2) + torch.mean(output_disc_fake_a ** 2))

            # lsgan loss for the discriminator_b
            loss_disc_b_lsgan = 0.5 * (torch.mean((output_disc_real_b - 1) ** 2) + torch.mean(output_disc_fake_b ** 2))

            # cycle-consistency loss(a)
            reconstructed_a = gen_model_b(torch.cat((condition_vectors, outputs_gen_a_to_b), 1))
            l1_loss_rec_a = F.l1_loss(reconstructed_a, inputs)

            # cycle-consistency loss(b)
            reconstructed_b = gen_model_a(torch.cat((condition_vectors, outputs_gen_b_to_a), 1))
            l1_loss_rec_b = F.l1_loss(reconstructed_b, answers)

            # lsgan loss for the generator_a
            loss_gen_lsgan_a = 0.5 * torch.mean((output_disc_fake_b - 1) ** 2)

            # lsgan loss for the generator_b
            loss_gen_lsgan_b = 0.5 * torch.mean((output_disc_fake_a - 1) ** 2)

            # total generator objective: both adversarial terms plus the cycle
            # losses weighted by a fixed 0.005
            loss_gen_total_lsgan = loss_gen_lsgan_a + loss_gen_lsgan_b + 0.005 * (l1_loss_rec_a + l1_loss_rec_b)

            # discriminator_a
            # Before the backward pass, use the optimizer object to zero all of
            # the gradients of the parameters it will update.
Example No. 32
def train_phase(predictor, train, valid, args):

    # visualize
    plt.rcParams['font.size'] = 18
    plt.figure(figsize=(13, 5))
    ax = sns.scatterplot(x=train.x.ravel(),
                         y=train.y.ravel(),
                         color='blue',
                         s=55,
                         alpha=0.3)
    ax.plot(train.x.ravel(), train.t.ravel(), color='red', linewidth=2)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(-10, 10)
    ax.set_ylim(-15, 15)
    plt.legend(['Observation', 'Ground-truth'])
    plt.title('Training data set')
    plt.tight_layout()
    plt.savefig(os.path.join(args.out, 'train_dataset.png'))
    plt.close()

    # setup iterators
    train_iter = iterators.SerialIterator(train, args.batchsize, shuffle=True)
    valid_iter = iterators.SerialIterator(valid,
                                          args.batchsize,
                                          repeat=False,
                                          shuffle=False)

    # setup a model
    device = torch.device(args.gpu)

    lossfun = noised_mean_squared_error
    # report the MAE of the predictor's first output (y[0]) as 'accuracy'
    accfun = lambda y, t: F.l1_loss(y[0], t)

    # Regressor wraps the predictor, reporting lossfun as 'main/loss' and
    # accfun as 'main/accuracy'
    model = Regressor(predictor, lossfun=lossfun, accfun=accfun)
    model.to(device)

    # setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 weight_decay=max(args.decay, 0))

    # setup a trainer
    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                model,
                                                device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(extensions.Evaluator(valid_iter, model, device=device))

    # trainer.extend(DumpGraph(model, 'main/loss'))

    # take a snapshot every `frequency` epochs (default: once, at the final epoch)
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    trainer.extend(extensions.LogReport())

    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch',
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch',
                file_name='accuracy.png'))

        trainer.extend(
            extensions.PlotReport(
                ['main/predictor/sigma', 'validation/main/predictor/sigma'],
                'epoch',
                file_name='sigma.png'))

    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
            'main/predictor/sigma', 'validation/main/predictor/sigma',
            'elapsed_time'
        ]))

    trainer.extend(extensions.ProgressBar())

    if args.resume:
        trainer.load_state_dict(torch.load(args.resume))

    trainer.run()

    torch.save(predictor.state_dict(), os.path.join(args.out, 'predictor.pth'))
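
# A minimal invocation sketch (hypothetical argument values; the real script
# builds `args` with argparse):
#
#   args = argparse.Namespace(out='result', batchsize=32, gpu='cuda:0',
#                             decay=1e-4, epoch=100, frequency=-1,
#                             plot=True, resume='')
#   train_phase(predictor, train_set, valid_set, args)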
Example No. 33
            # feed the data forward through discriminator_b
            output_disc_real_b = disc_model_b(torch.cat((condition_vectors, answers), 1))
            output_disc_fake_b = disc_model_b(torch.cat((condition_vectors, outputs_gen_a_to_b), 1))

            # loss functions

            # lsgan loss for the discriminator_a
            loss_disc_a_lsgan = 0.5 * (torch.mean((output_disc_real_a - 1) ** 2) + torch.mean(output_disc_fake_a ** 2))

            # lsgan loss for the discriminator_b
            loss_disc_b_lsgan = 0.5 * (torch.mean((output_disc_real_b - 1) ** 2) + torch.mean(output_disc_fake_b ** 2))

            # cycle-consistency loss(a)
            reconstructed_a = gen_model_b(torch.cat((condition_vectors, outputs_gen_a_to_b), 1))
            l1_loss_rec_a = F.l1_loss(reconstructed_a, inputs)

            # cycle-consistency loss(b)
            reconstructed_b = gen_model_a(torch.cat((condition_vectors, outputs_gen_b_to_a), 1))
            l1_loss_rec_b = F.l1_loss(reconstructed_b, answers)

            # identity loss(a)
            l1_loss_identity_a = F.l1_loss(outputs_gen_a_to_a, inputs)
            
            # identity loss(b)
            l1_loss_identity_b = F.l1_loss(outputs_gen_b_to_b, answers)
            
            # lsgan loss for the generator_a
            loss_gen_lsgan_a = 0.5 * torch.mean((output_disc_fake_b - 1) ** 2)

            # lsgan loss for the generator_b
Example No. 34
def regression_mae(preds, targets):
    return F.l1_loss(preds, targets)
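
# Sanity-check sketch (hypothetical tensors): F.l1_loss with the default
# reduction is exactly the mean absolute error.
#
#   preds, targets = torch.randn(8, 3), torch.randn(8, 3)
#   assert torch.allclose(regression_mae(preds, targets),
#                         (preds - targets).abs().mean())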
Example No. 35
            output_disc_real_b = disc_model_b(torch.cat((condition_vectors, answers), 1))
            output_disc_fake_b = disc_model_b(torch.cat((condition_vectors, outputs_gen_a_to_b), 1))

            # loss functions

            # lsgan loss for the discriminator_a
            loss_disc_a_lsgan = 0.5 * (torch.mean((output_disc_real_a - 1) ** 2) + torch.mean(output_disc_fake_a ** 2))

            # lsgan loss for the discriminator_b
            loss_disc_b_lsgan = 0.5 * (torch.mean((output_disc_real_b - 1) ** 2) + torch.mean(output_disc_fake_b ** 2))

            loss_disc_total = loss_disc_a_lsgan + loss_disc_b_lsgan
            
            # cycle-consistency loss(a)
            reconstructed_a = gen_model_b(torch.cat((condition_vectors, outputs_gen_a_to_b), 1))
            l1_loss_rec_a = F.l1_loss(reconstructed_a, inputs)
            l1_loss_rec_a = l1_loss_rec_a * 0.005

            # cycle-consistency loss(b)
            reconstructed_b = gen_model_a(torch.cat((condition_vectors, outputs_gen_b_to_a), 1))
            l1_loss_rec_b = F.l1_loss(reconstructed_b, answers)
            l1_loss_rec_b = l1_loss_rec_b * 0.005

            # identity loss (a)
            l1_loss_identity_a = F.l1_loss(outputs_gen_b_to_b[:, 1, :, :], answers[:, 1, :, :])
            l1_loss_identity_a = 0.05 * l1_loss_identity_a

            # identity loss (b)
            l1_loss_identity_b = F.l1_loss(outputs_gen_b_to_b[:, 2, :, :], answers[:, 2, :, :])
            l1_loss_identity_b = 0.05 * l1_loss_identity_b
            
            # lsgan loss for the generator_a
Example No. 36
def inference():
    netG = Generator(hp.num_mels, hp.vocoder_ngf, hp.vocoder_n_residual_layers).cuda()
    fft = Audio2Mel(n_fft=hp.n_fft,
                    hop_length=hp.hop_length,
                    win_length=hp.win_length,
                    sampling_rate=hp.sample_rate,
                    n_mel_channels=hp.num_mels,
                    mel_fmin=hp.fmin,
                    mel_fmax=hp.fmax,
                    min_level_db=hp.min_level_db).cuda()
    # load the newest checkpoint unless one is given explicitly; initialize
    # ckpt so the check below cannot hit an undefined name
    ckpt = None
    if args.model:
        logging.info(f'loading vocoder ckpt {args.model}')
        ckpt = torch.load(args.model)
    else:
        ckpts = sorted(list(Path(hp.vocoder_save_dir).glob('*.pt')))
        if len(ckpts) > 0:
            latest_ckpt_path = ckpts[-1]
            logging.info(f'loading vocoder ckpt {latest_ckpt_path}')
            ckpt = torch.load(latest_ckpt_path)
    if ckpt:
        netG.load_state_dict(ckpt['netG_state_dict'])
    else:
        print('no checkpoints')

    # seen sample
    # f = '/mnt/ssd500/dataset/speech/ST-CMDS-20170001_1-OS/20170001P00442A0027.wav'
    # f = '/mnt/ssd500/dataset/speech/ST-CMDS-20170001_1-OS/20170001P00435A0028.wav'

    # english (unseen speaker)
    # f = '/mnt/ssd500/dataset/speech/libritts/extract/LibriTTS/train-other-500/428/125877/428_125877_000015_000000.wav'
    # f = '/mnt/ssd500/dataset/speech/libritts/extract/LibriTTS/train-other-500/428/125877/428_125877_000087_000001.wav'
    # f = '/mnt/ssd500/dataset/speech/libritts/extract/LibriTTS/train-other-500/428/125877/428_125877_000065_000000.wav'

    # chinese (unseen speaker)
    f = '/mnt/ssd500/dataset/speech/ST_CMDS_holdout/20170001P00213I0037.wav'
    # f = '/mnt/wd500/dataset/speech/cn-celeb/extract/CN-Celeb/eval/enroll/id00987-enroll.wav'
    # f = '/mnt/wd500/dataset/speech/cn-celeb/extract/CN-Celeb/eval/enroll/id00998-enroll.wav'
    # f = '/mnt/wd500/dataset/speech/cn-celeb/extract/CN-Celeb/eval/enroll/id00960-enroll.wav'
    # f = '/mnt/wd500/dataset/speech/MAGICDATA-SLR68/extract/train/14_4030/14_4030_20170905174343.wav'

    # chinese (seen speaker)
    # f = '/mnt/ssd500/dataset/speech/ST_CMDS_holdout/20170001P00014A0120.wav'
    # f = '/mnt/ssd500/dataset/speech/ST_CMDS_holdout/20170001P00096I0120.wav'
    # f = '/mnt/ssd500/dataset/speech/ST_CMDS_holdout/20170001P00122A0120.wav'
    # f = '/mnt/ssd500/dataset/speech/aishell/data_aishell/wav/test/S0902/BAC009S0902W0477.wav'
    # f = '/mnt/ssd500/dataset/speech/primewords/extract/primewords_md_2018_set1/audio_files/0/0e/0eb1f442-f6b3-4e8c-abd7-e5720b4bdb99.wav'
    # f = '/mnt/ssd500/dataset/speech/ST-CMDS-20170001_1-OS/20170001P00179A0076.wav'
    # f = '/mnt/ssd500/dataset/speech/ST-CMDS-20170001_1-OS/20170001P00001A0003.wav'
    # f = '/mnt/ssd500/dataset/speech/ST-CMDS-20170001_1-OS/20170001P00001A0027.wav'
    # f = '/mnt/ssd500/dataset/speech/ST-CMDS-20170001_1-OS/20170001P00299I0067.wav'
    # f = '/mnt/ssd500/dataset/speech/ST-CMDS-20170001_1-OS/20170001P00299I0070.wav'
    # f = '/mnt/ssd500/dataset/speech/aishell/data_aishell/wav/train/S0169/BAC009S0169W0317.wav'
    # f = '/mnt/ssd500/dataset/speech/aishell/data_aishell/wav/train/S0169/BAC009S0169W0400.wav'
    uttrn = dataset.Utterance(id=None, raw_file=f)
    y = torch.from_numpy(uttrn.raw(sr=hp.sample_rate)).cuda()
    S = fft(y.unsqueeze(0).unsqueeze(0))
    # S = torch.from_numpy(uttrn.melspectrogram()).cuda().unsqueeze(0)
    netG.eval()  # inference: no gradients or train-mode statistics needed
    with torch.no_grad():
        y_pred = netG(S)
        S_recon = fft(y_pred)
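    # lower L1/MSE between the input mel and the one recomputed from the
    # generated waveform indicates a more faithful vocoder round-trip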
    l1loss = F.l1_loss(S, S_recon)
    mseloss = F.mse_loss(S, S_recon)
    print(f'y.shape {y.shape}, S.shape {S.shape} y_pred.shape {y_pred.shape}')
    print(f'S.mean {S.mean()} S_recon.mean {S_recon.mean()} l1loss {l1loss:.5f} mseloss {mseloss:.5f}')
    results = [y.detach().cpu().numpy(), S.detach().cpu().numpy(), y_pred.detach().cpu().numpy()]
    with open('results.pkl', 'wb') as fout:  # avoid shadowing the audio path `f`
        pickle.dump(results, fout)
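
    # To inspect the dump later (sketch), reload in the same order it was saved:
    #
    #   with open('results.pkl', 'rb') as fin:
    #       y_np, S_np, y_pred_np = pickle.load(fin)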