Example #1
0
def train_loop(device, model, optimizer, data_loader, log_dir):
    """Train ``model`` with MSE loss and gradient accumulation.

    Args:
        device: torch device that the criterion and batches are moved to.
        model: network called as ``model(x, sorted_lengths)``.
        optimizer: optimizer whose learning rate is overwritten each step.
        data_loader: iterable yielding ``(x, input_lengths, y)`` batches.
        log_dir: directory handed to ``save_checkpoint``.

    Relies on the module-level counters ``global_step`` / ``global_epoch``
    and on the hyper-parameter module ``hp``.
    """
    global global_step, global_epoch

    criterion = torch.nn.MSELoss().to(device)

    while global_epoch < hp.epochs:
        running_loss = 0
        for i, (x, input_lengths, y) in enumerate(data_loader):

            # Sort the batch by descending sequence length (the model
            # presumably needs sorted lengths, e.g. for packed sequences
            # -- TODO confirm against the model implementation).
            sorted_lengths, indices = torch.sort(input_lengths.view(-1),
                                                 dim=0,
                                                 descending=True)
            sorted_lengths = sorted_lengths.long().numpy()
            # Reorder inputs/targets to match the sorted lengths.
            x, y = x[indices], y[indices]

            x, y = x.to(device), y.to(device)
            y_hat = model(x, sorted_lengths)
            loss = criterion(y_hat, y)

            # Compute the scheduled learning rate and apply it.
            if hp.fixed_learning_rate:
                current_lr = hp.fixed_learning_rate
            elif hp.lr_schedule_type == 'step':
                current_lr = step_learning_rate_decay(hp.init_learning_rate,
                                                      global_step,
                                                      hp.step_gamma,
                                                      hp.lr_step_interval)
            else:
                current_lr = noam_learning_rate_decay(hp.init_learning_rate,
                                                      global_step,
                                                      hp.noam_warm_up_steps)

            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            # Gradient accumulation: scale the loss so the summed gradients
            # over hp.accumulation_steps mini-batches match one large batch.
            scaled_loss = loss / hp.accumulation_steps
            scaled_loss.backward()  # BP for gradient
            if ((i + 1) % hp.accumulation_steps) == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.grad_norm)  # clip gradient norm
                optimizer.step()  # update parameters of net
                optimizer.zero_grad()

            # Bug fix: track the UNSCALED loss. Previously the loss divided
            # by hp.accumulation_steps was accumulated, so the reported
            # loss/avg_loss were understated by that factor and inconsistent
            # with the non-accumulating train loops in this file.
            running_loss += loss.item()
            avg_loss = running_loss / (i + 1)

            # saving checkpoint
            if global_step != 0 and global_step % hp.checkpoint_interval == 0:
                save_checkpoint(device, model, optimizer, global_step,
                                global_epoch, log_dir)
            global_step += 1

        print("epoch:{:4d}  [loss={:.5f}, avg_loss={:.5f}, current_lr={}]".
              format(global_epoch, running_loss, avg_loss, current_lr))
        global_epoch += 1
Example #2
0
def train_loop(device, model, data_loader, optimizer, checkpoint_dir):
    """Main training loop.

    Picks a loss function from ``hp.input_type`` / ``hp.distribution``,
    then trains until ``hp.nepochs`` epochs, saving a checkpoint every
    ``hp.save_every_step`` steps and evaluating every
    ``hp.evaluate_every_step`` steps.

    Raises:
        ValueError: if ``hp.input_type`` (or, for ``'raw'``,
            ``hp.distribution``) is not supported.
    """
    # create loss and put on device
    if hp.input_type == 'raw':
        if hp.distribution == 'beta':
            criterion = beta_mle_loss
        elif hp.distribution == 'gaussian':
            criterion = gaussian_loss
        else:
            # Bug fix: previously an unsupported distribution left
            # `criterion` unbound, producing a confusing NameError on the
            # first batch instead of a clear configuration error.
            raise ValueError(
                "distribution:{} not supported".format(hp.distribution))
    elif hp.input_type == 'mixture':
        criterion = discretized_mix_logistic_loss
    elif hp.input_type in ["bits", "mulaw"]:
        criterion = nll_loss
    else:
        raise ValueError("input_type:{} not supported".format(hp.input_type))

    global global_step, global_epoch, global_test_step
    while global_epoch < hp.nepochs:
        running_loss = 0
        for i, (x, m, y) in enumerate(tqdm(data_loader)):
            x, m, y = x.to(device), m.to(device), y.to(device)
            y_hat = model(x, m)
            y = y.unsqueeze(-1)
            loss = criterion(y_hat, y)
            # calculate learning rate and update learning rate
            if hp.fix_learning_rate:
                current_lr = hp.fix_learning_rate
            elif hp.lr_schedule_type == 'step':
                current_lr = step_learning_rate_decay(hp.initial_learning_rate, global_step, hp.step_gamma, hp.lr_step_interval)
            else:
                current_lr = noam_learning_rate_decay(hp.initial_learning_rate, global_step, hp.noam_warm_up_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr
            optimizer.zero_grad()
            loss.backward()
            # clip gradient norm
            nn.utils.clip_grad_norm_(model.parameters(), hp.grad_norm)
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / (i+1)
            # saving checkpoint if needed
            if global_step != 0 and global_step % hp.save_every_step == 0:
                save_checkpoint(device, model, optimizer, global_step, checkpoint_dir, global_epoch)
            # evaluate model if needed; skipped right after a resume while
            # global_test_step is truthy.
            if global_step != 0 and global_test_step != True and global_step % hp.evaluate_every_step == 0:
                print("step {}, evaluating model: generating wav from mel...".format(global_step))
                evaluate_model(model, data_loader, checkpoint_dir)
                print("evaluation finished, resuming training...")

            # reset global_test_step status after evaluation
            if global_test_step is True:
                global_test_step = False
            global_step += 1

        print("epoch:{}, running loss:{}, average loss:{}, current lr:{}".format(global_epoch, running_loss, avg_loss, current_lr))
        global_epoch += 1
Example #3
0
    def train(self, global_step=0, global_epoch=1):
        """Run the optimization loop until ``self.epoch`` epochs are done.

        Per step: Noam LR schedule, forward pass, spectral loss, backprop,
        and TensorBoard scalar logging. Per epoch: optional checkpointing,
        state dumps, evaluation, and an average-loss log.
        """
        while global_epoch < self.epoch:
            total_loss = 0.
            for step, (src_mel, tgt_mel, seq_lens) in enumerate(tqdm(self.train_loader)):
                self.model.train()

                # Noam schedule, pushed into every parameter group.
                current_lr = noam_learning_rate_decay(self.args.learn_rate, global_step)
                for group in self.optimizer.param_groups:
                    group['lr'] = current_lr
                self.optimizer.zero_grad()

                # Move the batch onto the training device.
                src_mel = src_mel.to(self.device)
                tgt_mel = tgt_mel.to(self.device)
                seq_lens = seq_lens.to(self.device)

                target_mask = sequence_mask(seq_lens, max_len=tgt_mel.size(1)).unsqueeze(-1)

                # Forward pass
                prediction = self.model(src_mel) # TODO : code model

                # Weighted mix of L1 and binary-divergence spectral losses.
                mel_l1_loss, mel_binary_div = self.spec_loss(prediction, tgt_mel, target_mask)
                loss = (1 - self.w) * mel_l1_loss + self.w * mel_binary_div

                # Backprop and parameter update
                loss.backward()
                self.optimizer.step()

                # Scalar logging
                for tag, value in (
                        ("loss", float(loss.item())),
                        ("mel_l1_loss", float(mel_l1_loss.item())),
                        ("mel_binary_div_loss", float(mel_binary_div.item())),
                        ("learning rate", current_lr)):
                    self.writer.add_scalar(tag, value, global_step)

                global_step += 1
                total_loss += loss.item()

            if global_epoch % self.checkpoint_interval == 0:
                self.save_checkpoint(global_step, global_epoch)
            if global_epoch % self.eval_interval == 0:
                # Uses the last batch of the epoch for the state dump.
                self.save_states(global_epoch, prediction, src_mel, tgt_mel, seq_lens)
            self.eval_model(global_epoch)
            avg_loss = total_loss / len(self.train_loader)
            self.writer.add_scalar("train loss (per epoch)", avg_loss, global_epoch)
            print("Train Loss: {}".format(avg_loss))
            global_epoch += 1
Example #4
0
def train_loop(device, model, optimizer, data_loader, log_dir):
    """Plain MSE training loop with a scheduled learning rate.

    Trains until ``hp.epochs``, writing a checkpoint to ``log_dir`` every
    ``hp.checkpoint_interval`` global steps. Uses the module-level
    ``global_step`` / ``global_epoch`` counters and the ``hp`` settings.
    """
    global global_step, global_epoch

    criterion = torch.nn.MSELoss().to(device)

    while global_epoch < hp.epochs:
        running_loss = 0
        for batch_idx, (inputs, targets) in enumerate(data_loader):

            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs)
            loss = criterion(predictions, targets)

            # Select this step's learning rate per the configured schedule
            # and write it into every optimizer parameter group.
            if hp.fixed_learning_rate:
                current_lr = hp.fixed_learning_rate
            elif hp.lr_schedule_type == 'step':
                current_lr = step_learning_rate_decay(
                    hp.init_learning_rate, global_step,
                    hp.step_gamma, hp.lr_step_interval)
            else:
                current_lr = noam_learning_rate_decay(
                    hp.init_learning_rate, global_step, hp.noam_warm_up_steps)
            for group in optimizer.param_groups:
                group['lr'] = current_lr

            # Standard update: clear grads, backprop, clip, step.
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), hp.grad_norm)
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / (batch_idx + 1)

            # Periodic checkpoint (never at step 0).
            if global_step != 0 and global_step % hp.checkpoint_interval == 0:
                save_checkpoint(device, model, optimizer, global_step,
                                global_epoch, log_dir)
            global_step += 1

        print("epoch:{:4d}  [loss={:.5f}, avg_loss={:.5f}, current_lr={}]".format(
            global_epoch, running_loss, avg_loss, current_lr))
        global_epoch += 1
Example #5
0
def get_learning_rate(global_step, n_iters):
    """Return the learning rate for ``global_step`` under the configured schedule.

    Schedules (selected via ``hp``): fixed LR, step decay, a triangular
    one-cycle policy with a linear annealing tail, and Noam decay as the
    default.

    Args:
        global_step: current global training step.
        n_iters: iterations per epoch (used to size the one-cycle schedule).
    """
    if hp.fix_learning_rate:
        return hp.fix_learning_rate

    if hp.lr_schedule_type == 'step':
        return step_learning_rate_decay(hp.initial_learning_rate,
                                        global_step, hp.step_gamma,
                                        hp.lr_step_interval)

    if hp.lr_schedule_type == 'one-cycle':
        max_iters = n_iters * hp.nepochs
        cycle_width = int(max_iters * hp.fine_tune)
        step_size = cycle_width // 2
        if global_step < cycle_width:
            # Triangular phase: LR ramps linearly min_lr -> max_lr -> min_lr.
            cycle = np.floor(1 + global_step / (2 * step_size))
            x = abs(global_step / step_size - 2 * cycle + 1)
            return hp.min_lr + (hp.max_lr - hp.min_lr) * max(0, (1 - x))
        # Tail phase: anneal linearly from min_lr down to 1% of min_lr.
        x = (max_iters - global_step) / (max_iters - cycle_width)
        return 0.01 * hp.min_lr + 0.99 * hp.min_lr * x

    return noam_learning_rate_decay(hp.initial_learning_rate,
                                    global_step, hp.noam_warm_up_steps)
Example #6
0
    def train(self,
              train_seq2seq,
              train_postnet,
              global_epoch=1,
              global_step=0):
        """Train the seq2seq and/or postnet parts of the model.

        Args:
            train_seq2seq: if truthy, compute and optimize the mel loss.
            train_postnet: if truthy, compute and optimize the linear loss.
            global_epoch: starting epoch counter.
            global_step: starting step counter.

        Raises:
            ValueError: if both ``train_seq2seq`` and ``train_postnet`` are
                falsy (there would be no loss to optimize).

        NOTE(review): the single-component code paths in the forward pass
        are commented out below, so in practice only the
        ``train_seq2seq and train_postnet`` configuration produces model
        outputs — confirm before enabling a single-flag run.
        """
        # Bug fix: with neither flag set, `loss` was unbound and the loop
        # crashed with a NameError on the first batch.
        if not (train_seq2seq or train_postnet):
            raise ValueError(
                "at least one of train_seq2seq/train_postnet must be set")
        while global_epoch < self.epoch:
            running_loss = 0.
            running_linear_loss = 0.
            running_mel_loss = 0.
            for step, (ling, mel, linear, lengths,
                       speaker_ids) in enumerate(tqdm(self.train_loader)):
                self.model.train()
                ismultispeaker = speaker_ids is not None
                # Learn rate scheduler
                current_lr = noam_learning_rate_decay(
                    self.hparams.initial_learning_rate, global_step)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = current_lr
                self.optimizer.zero_grad()

                # Transform data to CUDA device
                if train_seq2seq:
                    ling = ling.to(self.device)
                    mel = mel.to(self.device)
                if train_postnet:
                    linear = linear.to(self.device)
                lengths = lengths.to(self.device)
                speaker_ids = speaker_ids.to(
                    self.device) if ismultispeaker else None
                target_mask = sequence_mask(lengths,
                                            max_len=mel.size(1)).unsqueeze(-1)

                # Apply model
                if train_seq2seq and train_postnet:
                    _, mel_outputs, linear_outputs = self.model(
                        ling, mel, speaker_ids=speaker_ids)
                #elif train_seq2seq:
                #    mel_style = self.model.gst(tmel)
                #    style_embed = mel_style.expand_as(smel)
                #    mel_input = smel + style_embed
                #    mel_outputs = self.model.seq2seq(mel_input)
                #    linear_outputs = None
                #elif train_postnet:
                #    linear_outputs = self.model.postnet(smel)
                #    mel_outputs = None

                # Losses
                if train_seq2seq:
                    mel_l1_loss, mel_binary_div = self.spec_loss(
                        mel_outputs, mel, target_mask)
                    mel_loss = (1 -
                                self.w) * mel_l1_loss + self.w * mel_binary_div
                if train_postnet:
                    linear_l1_loss, linear_binary_div = self.spec_loss(
                        linear_outputs, linear, target_mask)
                    linear_loss = (
                        1 -
                        self.w) * linear_l1_loss + self.w * linear_binary_div

                # Combine losses
                if train_seq2seq and train_postnet:
                    loss = mel_loss + linear_loss
                elif train_seq2seq:
                    loss = mel_loss
                elif train_postnet:
                    loss = linear_loss

                # Update
                loss.backward()
                self.optimizer.step()
                # Logs
                if train_seq2seq:
                    self.writer.add_scalar("mel loss", float(mel_loss.item()),
                                           global_step)
                    self.writer.add_scalar("mel_l1_loss",
                                           float(mel_l1_loss.item()),
                                           global_step)
                    self.writer.add_scalar("mel_binary_div_loss",
                                           float(mel_binary_div.item()),
                                           global_step)
                if train_postnet:
                    self.writer.add_scalar("linear_loss",
                                           float(linear_loss.item()),
                                           global_step)
                    self.writer.add_scalar("linear_l1_loss",
                                           float(linear_l1_loss.item()),
                                           global_step)
                    self.writer.add_scalar("linear_binary_div_loss",
                                           float(linear_binary_div.item()),
                                           global_step)
                self.writer.add_scalar("loss", float(loss.item()), global_step)
                self.writer.add_scalar("learning rate", current_lr,
                                       global_step)

                global_step += 1
                running_loss += loss.item()
                # Bug fix: these accumulations were unconditional, raising
                # NameError whenever the corresponding flag was disabled
                # (linear_loss / mel_loss would be unbound).
                if train_postnet:
                    running_linear_loss += linear_loss.item()
                if train_seq2seq:
                    running_mel_loss += mel_loss.item()

            if (global_epoch % self.checkpoint_interval == 0):
                self.save_checkpoint(global_step, global_epoch)
            if global_epoch % self.eval_interval == 0:
                self.save_states(global_epoch, mel_outputs, linear_outputs,
                                 ling, mel, linear, lengths)
            self.eval_model(global_epoch, train_seq2seq, train_postnet)
            avg_loss = running_loss / len(self.train_loader)
            avg_linear_loss = running_linear_loss / len(self.train_loader)
            avg_mel_loss = running_mel_loss / len(self.train_loader)
            self.writer.add_scalar("train loss (per epoch)", avg_loss,
                                   global_epoch)
            self.writer.add_scalar("train linear loss (per epoch)",
                                   avg_linear_loss, global_epoch)
            self.writer.add_scalar("train mel loss (per epoch)", avg_mel_loss,
                                   global_epoch)
            print("Train Loss: {}".format(avg_loss))
            global_epoch += 1
Example #7
0
def train_loop(device, model, data_loader, optimizer, checkpoint_dir):
    """Training loop with weight pruning and TensorBoard logging.

    Selects a loss function from ``hp.input_type`` / ``hp.distribution``,
    prunes selected model layers during training via ``Pruner``, and saves
    checkpoints / runs evaluation on the configured step intervals.

    Raises:
        ValueError: if ``hp.input_type`` (or, for ``'raw'``,
            ``hp.distribution``) is not supported.
    """
    # create loss and put on device
    if hp.input_type == 'raw':
        if hp.distribution == 'beta':
            criterion = beta_mle_loss
        elif hp.distribution == 'gaussian':
            criterion = gaussian_loss
        else:
            # Bug fix: previously an unsupported distribution left
            # `criterion` unbound, failing later with a NameError instead
            # of a clear configuration error.
            raise ValueError(
                "distribution:{} not supported".format(hp.distribution))
    elif hp.input_type == 'mixture':
        criterion = discretized_mix_logistic_loss
    elif hp.input_type in ["bits", "mulaw"]:
        criterion = nll_loss
    else:
        raise ValueError("input_type:{} not supported".format(hp.input_type))

    # Pruner for reducing memory footprint
    layers = [(model.I, hp.sparsity_target),
              (model.rnn1, hp.sparsity_target_rnn),
              (model.fc1, hp.sparsity_target), (model.fc3, hp.sparsity_target)
              ]  #(model.fc2,hp.sparsity_target),
    pruner = Pruner(layers, hp.start_prune, hp.prune_steps, hp.sparsity_target)

    global global_step, global_epoch, global_test_step
    while global_epoch < hp.nepochs:
        running_loss = 0
        for i, (x, m, y) in enumerate(tqdm(data_loader)):
            x, m, y = x.to(device), m.to(device), y.to(device)
            y_hat = model(x, m)
            y = y.unsqueeze(-1)
            loss = criterion(y_hat, y)
            # calculate learning rate and update learning rate
            if hp.fix_learning_rate:
                current_lr = hp.fix_learning_rate
            elif hp.lr_schedule_type == 'step':
                current_lr = step_learning_rate_decay(hp.initial_learning_rate,
                                                      global_step,
                                                      hp.step_gamma,
                                                      hp.lr_step_interval)
            else:
                current_lr = noam_learning_rate_decay(hp.initial_learning_rate,
                                                      global_step,
                                                      hp.noam_warm_up_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr
            optimizer.zero_grad()
            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(),
                                                 hp.grad_norm)
            optimizer.step()
            # Prune after the optimizer step so newly-updated weights are
            # included in the sparsity computation.
            num_pruned, z = pruner.prune(global_step)

            running_loss += loss.item()
            avg_loss = running_loss / (i + 1)

            writer.add_scalar("loss", float(loss.item()), global_step)
            writer.add_scalar("avg_loss", float(avg_loss), global_step)
            writer.add_scalar("learning_rate", float(current_lr), global_step)
            writer.add_scalar("grad_norm", float(grad_norm), global_step)
            writer.add_scalar("num_pruned", float(num_pruned), global_step)
            writer.add_scalar("fraction_pruned", z, global_step)

            # saving checkpoint if needed; prune first so the saved weights
            # reflect the current sparsity state.
            if global_step != 0 and global_step % hp.save_every_step == 0:
                pruner.prune(global_step)
                save_checkpoint(device, model, optimizer, global_step,
                                checkpoint_dir, global_epoch)
            # evaluate model if needed; skipped while global_test_step is
            # truthy (i.e. right after a resume).
            if global_step != 0 and global_test_step != True and global_step % hp.evaluate_every_step == 0:
                pruner.prune(global_step)
                print("step {}, evaluating model: generating wav from mel...".
                      format(global_step))
                evaluate_model(model, data_loader, checkpoint_dir)
                print("evaluation finished, resuming training...")

            # reset global_test_step status after evaluation
            if global_test_step is True:
                global_test_step = False
            global_step += 1

        print(
            "epoch:{}, running loss:{}, average loss:{}, current lr:{}, num_pruned:{} ({}%)"
            .format(global_epoch, running_loss, avg_loss, current_lr,
                    num_pruned, z))
        global_epoch += 1