Example #1
def train(args):

    os.makedirs(args.checkpoint_dir, exist_ok=True)
    logging = GetLogging(args.logfile)

    train_dataset = CustomerDataset(args.input,
                                    upsample_factor=hop_length,
                                    local_condition=True,
                                    global_condition=False)

    device = torch.device("cuda" if args.use_cuda else "cpu")
    generator, discriminator = create_model(args)

    print(generator)
    print(discriminator)

    num_gpu = torch.cuda.device_count() if args.use_cuda else 1

    global_step = 0

    g_parameters = list(generator.parameters())
    g_optimizer = optim.Adam(g_parameters, lr=args.g_learning_rate)

    d_parameters = list(discriminator.parameters())
    d_optimizer = optim.Adam(d_parameters, lr=args.d_learning_rate)

    writer = Writer(args.checkpoint_dir, sample_rate=sample_rate)

    generator.to(device)
    discriminator.to(device)

    if args.resume is not None:
        restore_step = attempt_to_restore(generator, discriminator,
                                          g_optimizer, d_optimizer,
                                          args.resume, args.use_cuda, logging)
        global_step = restore_step

    customer_g_optimizer = Optimizer(g_optimizer, args.g_learning_rate,
                                     global_step, args.warmup_steps,
                                     args.decay_learning_rate)
    customer_d_optimizer = Optimizer(d_optimizer, args.d_learning_rate,
                                     global_step, args.warmup_steps,
                                     args.decay_learning_rate)

    criterion = nn.MSELoss().to(device)
    stft_criterion = MultiResolutionSTFTLoss()

    for epoch in range(args.epochs):

        collate = CustomerCollate(upsample_factor=hop_length,
                                  condition_window=args.condition_window,
                                  local_condition=True,
                                  global_condition=False)

        train_data_loader = DataLoader(train_dataset,
                                       collate_fn=collate,
                                       batch_size=args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True,
                                       pin_memory=True)

        #train one epoch
        for batch, (samples, conditions) in enumerate(train_data_loader):

            start = time.time()
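            # trim the batch to a multiple of the GPU count so it splits
            # evenly across devices when running data-parallel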
            batch_size = int(conditions.shape[0] // num_gpu * num_gpu)

            samples = samples[:batch_size, :].to(device)
            conditions = conditions[:batch_size, :, :].to(device)

            losses = {}

            if num_gpu > 1:
                g_outputs = parallel(generator, (conditions, ))
            else:
                g_outputs = generator(conditions)

            sc_loss, mag_loss = stft_criterion(g_outputs.squeeze(1),
                                               samples.squeeze(1))

            g_loss = sc_loss + mag_loss

            losses['sc_loss'] = sc_loss.item()
            losses['mag_loss'] = mag_loss.item()
            losses['g_loss'] = g_loss.item()

            customer_g_optimizer.zero_grad()
            g_loss.backward()
            nn.utils.clip_grad_norm_(g_parameters, max_norm=0.5)
            customer_g_optimizer.step_and_update_lr()

            time_used = time.time() - start

            logging.info(
                "Step: {} --sc_loss: {:.3f} --mag_loss: {:.3f} --Time: {:.2f} seconds"
                .format(global_step, sc_loss, mag_loss, time_used))

            if global_step % args.checkpoint_step == 0:
                save_checkpoint(args, generator, discriminator, g_optimizer,
                                d_optimizer, global_step, logging)

            if global_step % args.summary_step == 0:
                writer.logging_loss(losses, global_step)
                target = samples.cpu().detach()[0, 0].numpy()
                predict = g_outputs.cpu().detach()[0, 0].numpy()
                writer.logging_audio(target, predict, global_step)
                writer.logging_histogram(generator, global_step)
                writer.logging_histogram(discriminator, global_step)

            global_step += 1
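The Optimizer wrapper used by every example here is external to these listings. A minimal sketch of the behaviour its call sites imply (warm the learning rate up over warmup_steps, optionally decay it afterwards, and apply it on every step_and_update_lr() call) could look like the following; the class name WarmupOptimizer and the exact schedule are assumptions, not the original implementation.

class WarmupOptimizer:
    """Sketch of a warmup/decay learning-rate wrapper around a torch optimizer."""

    def __init__(self, optimizer, base_lr, global_step, warmup_steps, decay_learning_rate):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.step_num = global_step
        self.warmup_steps = warmup_steps
        self.decay = decay_learning_rate  # treated as a flag here (assumption)
        self.lr = base_lr

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step_and_update_lr(self):
        self.step_num += 1
        if self.step_num < self.warmup_steps:
            # linear warmup towards the base learning rate
            lr = self.base_lr * self.step_num / self.warmup_steps
        elif self.decay:
            # inverse-square-root decay after warmup
            lr = self.base_lr * (self.warmup_steps / self.step_num) ** 0.5
        else:
            lr = self.base_lr
        self.lr = lr
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self.optimizer.step()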
Example #2
def train(args):

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    train_dataset = CustomerDataset(
         args.input,
         upsample_factor=hop_length,
         local_condition=True,
         global_condition=False)

    device = torch.device("cuda" if args.use_cuda else "cpu")
    generator, discriminator = create_model(args)

    print(generator)
    print(discriminator)

    num_gpu = torch.cuda.device_count() if args.use_cuda else 1

    global_step = 0

    g_parameters = list(generator.parameters())
    g_optimizer = optim.Adam(g_parameters, lr=args.g_learning_rate)

    d_parameters = list(discriminator.parameters())
    d_optimizer = optim.Adam(d_parameters, lr=args.d_learning_rate)
    
    writer = SummaryWriter(args.checkpoint_dir)

    generator.to(device)
    discriminator.to(device)

    if args.resume is not None:
        restore_step = attempt_to_restore(generator, discriminator, g_optimizer,
                                          d_optimizer, args.resume, args.use_cuda)
        global_step = restore_step

    customer_g_optimizer = Optimizer(g_optimizer, args.g_learning_rate,
                                     global_step, args.warmup_steps,
                                     args.decay_learning_rate)
    customer_d_optimizer = Optimizer(d_optimizer, args.d_learning_rate,
                                     global_step, args.warmup_steps,
                                     args.decay_learning_rate)

    stft_criterion = MultiResolutionSTFTLoss().to(device)
    criterion = nn.MSELoss().to(device)

    for epoch in range(args.epochs):

        collate = CustomerCollate(upsample_factor=hop_length,
                                  condition_window=args.condition_window,
                                  local_condition=True,
                                  global_condition=False)

        train_data_loader = DataLoader(train_dataset,
                                       collate_fn=collate,
                                       batch_size=args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True,
                                       pin_memory=True)

        #train one epoch
        for batch, (samples, conditions) in enumerate(train_data_loader):

            start = time.time()
            batch_size = int(conditions.shape[0] // num_gpu * num_gpu)

            samples = samples[:batch_size, :].to(device)
            conditions = conditions[:batch_size, :, :].to(device)
            z = torch.randn(batch_size, args.z_dim).to(device)

            losses = {}

            if num_gpu > 1:
                g_outputs = parallel(generator, (conditions, z))
            else:
                g_outputs = generator(conditions, z)

            #train discriminator
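            # least-squares GAN objective: push discriminator outputs for real
            # audio toward 1 and for generated audio toward 0 (MSE criterion)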
            if global_step > args.discriminator_train_start_steps:
                if num_gpu > 1:
                    real_outputs, fake_outputs = \
                        parallel(discriminator, (samples, g_outputs.detach(), conditions))
                else:
                    real_outputs, fake_outputs = \
                        discriminator(samples, g_outputs.detach(), conditions)

                fake_loss, real_loss = [], []
                for (fake_output, real_output) in zip(fake_outputs, real_outputs):
                    fake_loss.append(criterion(fake_output, torch.zeros_like(fake_output)))
                    real_loss.append(criterion(real_output, torch.ones_like(real_output)))
                #fake_loss = sum(fake_loss) / 10.0
                #real_loss = sum(real_loss) / 10.0
                fake_loss = sum(fake_loss)
                real_loss = sum(real_loss)

                d_loss = fake_loss + real_loss

                customer_d_optimizer.zero_grad()
                d_loss.backward()
                nn.utils.clip_grad_norm_(d_parameters, max_norm=0.5)
                customer_d_optimizer.step_and_update_lr()
            else:
                d_loss = torch.Tensor([0])
                fake_loss = torch.Tensor([0])
                real_loss = torch.Tensor([0])

            losses['fake_loss'] = fake_loss.item()
            losses['real_loss'] = real_loss.item()
            losses['d_loss'] = d_loss.item()

            #train generator
            if num_gpu > 1:
                _, fake_outputs = parallel(discriminator, (samples, g_outputs, conditions))
            else:
                _, fake_outputs = discriminator(samples, g_outputs, conditions)

            adv_loss = []
            for fake_output in fake_outputs:
                adv_loss.append(criterion(fake_output, torch.ones_like(fake_output)))

            #adv_loss = sum(adv_loss) / 10.0
            adv_loss = sum(adv_loss)

            sc_loss, mag_loss = stft_criterion(g_outputs.squeeze(1), samples.squeeze(1))

            if global_step > args.discriminator_train_start_steps:
                g_loss = adv_loss * args.lamda_adv + sc_loss + mag_loss
            else:
                g_loss = sc_loss + mag_loss

            losses['adv_loss'] = adv_loss.item()
            losses['sc_loss'] = sc_loss.item()
            losses['mag_loss'] = mag_loss.item()
            losses['g_loss'] = g_loss.item()
 
            customer_g_optimizer.zero_grad()
            g_loss.backward()
            nn.utils.clip_grad_norm_(g_parameters, max_norm=0.5)
            customer_g_optimizer.step_and_update_lr()

            time_used = time.time() - start
            if global_step > args.discriminator_train_start_steps:
                print("Step: {} --adv_loss: {:.3f} --real_loss: {:.3f} --fake_loss: {:.3f} --sc_loss: {:.3f} --mag_loss: {:.3f} --Time: {:.2f} seconds".format(
                   global_step, adv_loss, real_loss, fake_loss, sc_loss, mag_loss, time_used))
            else:
                print("Step: {} --sc_loss: {:.3f} --mag_loss: {:.3f} --Time: {:.2f} seconds".format(global_step, sc_loss, mag_loss, time_used))

            global_step += 1

            if global_step % args.checkpoint_step == 0:
                save_checkpoint(args, generator, discriminator,
                         g_optimizer, d_optimizer, global_step)
                
            if global_step % args.summary_step == 0:
                for key in losses:
                    writer.add_scalar('{}'.format(key), losses[key], global_step)
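MultiResolutionSTFTLoss is also not shown in these listings. A minimal sketch of what such a criterion typically computes (spectral-convergence and log-magnitude terms averaged over several STFT resolutions, as in the Parallel WaveGAN formulation) is given below; the class name and the particular resolutions are assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F


def _stft_magnitude(x, fft_size, hop_size, win_size):
    window = torch.hann_window(win_size, device=x.device)
    spec = torch.stft(x, fft_size, hop_size, win_size, window, return_complex=True)
    return spec.abs().clamp(min=1e-7)


class SimpleMultiResolutionSTFTLoss(nn.Module):
    def __init__(self, resolutions=((1024, 256, 1024),
                                    (2048, 512, 2048),
                                    (512, 128, 512))):
        super().__init__()
        self.resolutions = resolutions

    def forward(self, fake, real):
        sc_loss, mag_loss = 0.0, 0.0
        for fft_size, hop_size, win_size in self.resolutions:
            fake_mag = _stft_magnitude(fake, fft_size, hop_size, win_size)
            real_mag = _stft_magnitude(real, fft_size, hop_size, win_size)
            # spectral convergence: relative Frobenius-norm error
            sc_loss += torch.norm(real_mag - fake_mag, p="fro") / torch.norm(real_mag, p="fro")
            # log STFT magnitude loss
            mag_loss += F.l1_loss(torch.log(fake_mag), torch.log(real_mag))
        n = len(self.resolutions)
        return sc_loss / n, mag_loss / n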
Example #3
def train(args):

    os.makedirs(args.checkpoint_dir, exist_ok=True)
    os.makedirs(args.ema_checkpoint_dir, exist_ok=True)

    train_dataset = CustomerDataset(args.input,
                                    upsample_factor=hop_length,
                                    local_condition=True,
                                    global_condition=False)

    device = torch.device("cuda" if args.use_cuda else "cpu")
    generator, discriminator = create_model(args)

    print(generator)
    print(discriminator)

    num_gpu = torch.cuda.device_count() if args.use_cuda else 1

    global_step = 0

    g_parameters = list(generator.parameters())
    g_optimizer = optim.Adam(g_parameters, lr=args.g_learning_rate)

    d_parameters = list(discriminator.parameters())
    d_optimizer = optim.Adam(d_parameters, lr=args.d_learning_rate)

    writer = SummaryWriter(args.checkpoint_dir)

    generator.to(device)
    discriminator.to(device)

    if args.resume is not None:
        restore_step = attempt_to_restore(generator, discriminator,
                                          g_optimizer, d_optimizer,
                                          args.resume, args.use_cuda)
        global_step = restore_step

    ema = ExponentialMovingAverage(args.ema_decay)
    register_model_to_ema(generator, ema)

    customer_g_optimizer = Optimizer(g_optimizer, args.g_learning_rate,
                                     global_step, args.warmup_steps,
                                     args.decay_learning_rate)
    customer_d_optimizer = Optimizer(d_optimizer, args.d_learning_rate,
                                     global_step, args.warmup_steps,
                                     args.decay_learning_rate)

    criterion = nn.MSELoss().to(device)

    for epoch in range(args.epochs):

        collate = CustomerCollate(upsample_factor=hop_length,
                                  condition_window=args.condition_window,
                                  local_condition=True,
                                  global_condition=False)

        train_data_loader = DataLoader(train_dataset,
                                       collate_fn=collate,
                                       batch_size=args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True,
                                       pin_memory=True)

        #train one epoch
        for batch, (samples, conditions) in enumerate(train_data_loader):

            start = time.time()
            batch_size = int(conditions.shape[0] // num_gpu * num_gpu)

            samples = samples[:batch_size, :].to(device)
            conditions = conditions[:batch_size, :, :].to(device)
            z = torch.randn(batch_size, args.z_dim).to(device)

            #train generator
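            # generator objective: adversarial loss on the discriminator's
            # fake outputs plus a feature-matching term weighted by args.lamda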
            if num_gpu > 1:
                g_outputs = parallel(generator, (conditions, z))
                _, fake_outputs, real_features, fake_features = \
                   parallel(discriminator, (samples, g_outputs, conditions))
            else:
                g_outputs = generator(conditions, z)
                _, fake_outputs, real_features, fake_features = \
                   discriminator(samples, g_outputs, conditions)

            g_d_loss = []
            for fake_output in fake_outputs:
                target = torch.ones_like(fake_output).to(device)
                g_d_loss.append(criterion(fake_output, target))
            feature_loss = feature_loss_calculate(real_features, fake_features)
            g_loss = feature_loss * args.lamda + sum(g_d_loss)

            customer_g_optimizer.zero_grad()
            g_loss.backward()
            nn.utils.clip_grad_norm_(g_parameters, max_norm=0.5)
            customer_g_optimizer.step_and_update_lr()

            #train discriminator
            g_outputs = g_outputs.detach()
            if num_gpu > 1:
                real_outputs, fake_outputs, _, _ = \
                    parallel(discriminator, (samples, g_outputs, conditions))
            else:
                real_outputs, fake_outputs, _, _ = \
                    discriminator(samples, g_outputs, conditions)

            fake_loss, real_loss = [], []
            for (fake_output, real_output) in zip(fake_outputs, real_outputs):
                fake_target = torch.zeros_like(fake_output).to(device)
                real_target = torch.ones_like(real_output).to(device)
                fake_loss.append(criterion(fake_output, fake_target))
                real_loss.append(criterion(real_output, real_target))
            d_loss = sum(fake_loss) + sum(real_loss)

            customer_d_optimizer.zero_grad()
            d_loss.backward()
            nn.utils.clip_grad_norm_(d_parameters, max_norm=0.5)
            customer_d_optimizer.step_and_update_lr()

            global_step += 1

            print(
                "Step: {} --g_loss: {:.3f} --d_loss: {:.3f} --Time: {:.2f} seconds"
                .format(global_step, g_loss, d_loss,
                        float(time.time() - start)))
            print(feature_loss.item(), sum(g_d_loss).item(), d_loss.item())
            if ema is not None:
                apply_moving_average(generator, ema)

            if global_step % args.checkpoint_step == 0:
                save_checkpoint(args, generator, discriminator, g_optimizer,
                                d_optimizer, global_step, ema)

            if global_step % args.summary_step == 0:
                writer.add_scalar("g_loss", g_loss.item(), global_step)
                writer.add_scalar("d_loss", d_loss.item(), global_step)
Example #4
def train(args):

    os.makedirs(args.checkpoint_dir, exist_ok=True)
    os.makedirs(args.ema_checkpoint_dir, exist_ok=True)

    train_dataset = WaveRNNDataset(args.input,
                                   upsample_factor=hop_length,
                                   local_condition=True,
                                   global_condition=False)

    device = torch.device("cuda" if args.use_cuda else "cpu")
    model = create_model(args)

    print(model)

    num_gpu = torch.cuda.device_count() if args.use_cuda else 1

    model.train(mode=True)

    global_step = 0

    parameters = list(model.parameters())
    optimizer = optim.Adam(parameters, lr=args.learning_rate)

    writer = SummaryWriter(args.checkpoint_dir)

    model.to(device)

    if args.resume is not None:
        restore_step = attempt_to_restore(model, optimizer, args.resume,
                                          args.use_cuda)
        global_step = restore_step

    ema = ExponentialMovingAverage(args.ema_decay)
    register_model_to_ema(model, ema)

    customer_optimizer = Optimizer(optimizer, args.learning_rate, global_step,
                                   args.warmup_steps, args.decay_learning_rate)

    criterion = nn.NLLLoss().to(device)

    for epoch in range(args.epochs):

        collate = WaveRNNCollate(upsample_factor=hop_length,
                                 condition_window=args.condition_window,
                                 local_condition=True,
                                 global_condition=False)

        train_data_loader = DataLoader(train_dataset,
                                       collate_fn=collate,
                                       batch_size=args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True,
                                       pin_memory=True)

        #train one epoch
        for batch, (coarse, fine, condition) in enumerate(train_data_loader):

            start = time.time()
            batch_size = int(condition.shape[0] // num_gpu * num_gpu)

            coarse = coarse[:batch_size, :].to(device)
            fine = fine[:batch_size, :].to(device)
            condition = condition[:batch_size, :, :].to(device)
            inputs = torch.cat([coarse[:, :-1].unsqueeze(-1),
                                fine[:, :-1].unsqueeze(-1),
                                coarse[:, 1:].unsqueeze(-1)], dim=-1)
            # scale 8-bit coarse/fine values from [0, 255] to [-1.0, 1.0]
            inputs = 2 * inputs.float() / 255 - 1.0

            if num_gpu > 1:
                out_c, out_f, _ = parallel(model, (inputs, condition))
            else:
                out_c, out_f, _ = model(inputs, condition)

            loss_c = criterion(out_c.transpose(1, 2).float(), coarse[:, 1:])
            loss_f = criterion(out_f.transpose(1, 2).float(), fine[:, 1:])
            loss = loss_c + loss_f

            global_step += 1
            customer_optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(parameters, max_norm=0.5)
            customer_optimizer.step_and_update_lr()
            model.after_update()

            if ema is not None:
                apply_moving_average(model, ema)

            print(
                "Step: {} --loss_c: {:.3f} --loss_f: {:.3f} --Lr: {:g} --Time: {:.2f} seconds"
                .format(global_step, loss_c, loss_f, customer_optimizer.lr,
                        float(time.time() - start)))

            if global_step % args.checkpoint_step == 0:
                save_checkpoint(args, model, optimizer, global_step, ema)

            if global_step % args.summary_step == 0:
                writer.add_scalar("loss", loss.item(), global_step)
                writer.add_scalar("loss_c", loss_c.item(), global_step)
                writer.add_scalar("loss_f", loss_f.item(), global_step)