# The helpers referenced below (stft, criterion_l1/l2, criterion_t,
# criterion_frame, clone_as_averaged_model) and the module-level globals
# (device, args, state, train_loader, test_loader, global_step) are defined
# elsewhere in this file.
import time

import numpy as np
import torch
from torch import nn
from torch.distributions.normal import Normal


def train(epoch, model, optimizer, scheduler, ema):
    global global_step
    epoch_loss = 0.
    running_loss = [0., 0., 0., 0., 0.]
    model.train()
    start_time = time.time()
    display_step = 100
    for batch_idx, (x, _, c, _) in enumerate(train_loader):
        # Per-batch LR schedule (note: PyTorch >= 1.1 expects
        # scheduler.step() to be called after optimizer.step()).
        scheduler.step()
        global_step += 1
        x, c = x.to(device), c.to(device)

        optimizer.zero_grad()
        x_rec, x_prior, loss_rec, loss_kl = model(x, c)

        stft_rec, stft_rec_log = stft(x_rec[:, 0, 1:])
        stft_truth, stft_truth_log = stft(x[:, 0, 1:])
        stft_prior, stft_prior_log = stft(x_prior[:, 0, 1:])

        # Frame-level losses: L2 on linear magnitudes, L1 on log magnitudes.
        loss_frame_rec = criterion_l2(stft_rec, stft_truth) + criterion_l1(
            stft_rec_log, stft_truth_log)
        loss_frame_prior = criterion_l2(stft_prior, stft_truth) + criterion_l1(
            stft_prior_log, stft_truth_log)

        # KL annealing coefficient: sigmoid ramp from 0 toward 1,
        # centered at step 5e5 (alpha = 0.5 there).
        alpha = 1 / (1 + np.exp(-5e-5 * (global_step - 5e+5)))
        loss_rec, loss_kl = loss_rec.mean(), loss_kl.mean()
        loss_tot = loss_rec + loss_kl * alpha + loss_frame_rec + loss_frame_prior

        loss_tot.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 10.)
        optimizer.step()

        if ema is not None:
            for name, param in model.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

        running_loss[0] += loss_tot.item() / display_step
        running_loss[1] += loss_rec.item() / display_step
        running_loss[2] += loss_kl.item() / display_step
        running_loss[3] += loss_frame_rec.item() / display_step
        running_loss[4] += loss_frame_prior.item() / display_step
        epoch_loss += loss_tot.item()

        if (batch_idx + 1) % display_step == 0:
            end_time = time.time()
            print(
                'Global Step : {}, [{}, {}] [Total Loss, Rec Loss, KL Loss, STFT Recon, STFT Prior] : {}'
                .format(global_step, epoch, batch_idx + 1,
                        np.array(running_loss)))
            print('{} Step Time : {}'.format(display_step,
                                             end_time - start_time))
            start_time = time.time()
            running_loss = [0., 0., 0., 0., 0.]

        del loss_tot, loss_frame_rec, loss_frame_prior, loss_kl, loss_rec
        del x, c, x_rec, x_prior
        del stft_rec, stft_rec_log, stft_truth, stft_truth_log, stft_prior, stft_prior_log

    print('{} Epoch Training Loss : {:.4f}'.format(
        epoch, epoch_loss / len(train_loader)))
    return epoch_loss / len(train_loader)
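# The EMA object passed into train() is assumed to expose a `shadow` dict of
# parameter tensors and an `update(name, value)` method; its definition is not
# part of this section. A minimal sketch of that interface (the decay value is
# an assumption), for reference:
class EMA:
    """Sketch of an exponential-moving-average tracker over model weights."""

    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {name: param.data.clone()
                       for name, param in model.named_parameters()
                       if param.requires_grad}

    def update(self, name, value):
        # shadow <- decay * shadow + (1 - decay) * value
        self.shadow[name].mul_(self.decay).add_(value, alpha=1 - self.decay)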
def evaluate(model_t, model_s, ema=None):
    # Evaluate the EMA-averaged student against the frozen teacher; fall back
    # to the live student weights when no EMA is given.
    model_s_ema = clone_as_averaged_model(model_s, ema) if ema is not None else model_s
    model_t.eval()
    model_s_ema.eval()
    running_loss = [0., 0., 0., 0.]
    epoch_loss = 0.
    display_step = 100
    with torch.no_grad():
        for batch_idx, (x, y, c, _) in enumerate(test_loader):
            x, y, c = x.to(device), y.to(device), c.to(device)

            # Draw z ~ N(0, I) and let the student map it to a waveform.
            q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size()))
            z = q_0.sample()

            c_up = model_t.upsample(c)
            x_student, mu_s, logs_s = model_s_ema(z, c_up)
            mu_logs_t = model_t(x_student, c)

            if args.KL_type == 'pq':
                loss_t, loss_KL, loss_reg = criterion_t(mu_logs_t[:, 0:1, :-1],
                                                        mu_logs_t[:, 1:, :-1],
                                                        mu_s, logs_s)
            elif args.KL_type == 'qp':
                loss_t, loss_KL, loss_reg = criterion_t(mu_s, logs_s,
                                                        mu_logs_t[:, 0:1, :-1],
                                                        mu_logs_t[:, 1:, :-1])
            else:
                raise ValueError('Unknown KL_type: {}'.format(args.KL_type))

            stft_student = stft(x_student[:, 0, 1:], scale='linear')
            stft_truth = stft(x[:, 0, 1:], scale='linear')

            loss_frame = criterion_frame(stft_student, stft_truth.detach())
            loss_tot = loss_t + loss_frame

            running_loss[0] += loss_tot.item() / display_step
            running_loss[1] += loss_KL.item() / display_step
            running_loss[2] += loss_reg.item() / display_step
            running_loss[3] += loss_frame.item() / display_step
            epoch_loss += loss_tot.item()

            if (batch_idx + 1) % display_step == 0:
                print('{} [Total, KL, Reg, Frame Loss] : {}'.format(
                    batch_idx + 1, np.array(running_loss)))
                running_loss = [0., 0., 0., 0.]

            del loss_tot, loss_frame, loss_KL, loss_reg, loss_t
            del x, y, c, c_up, stft_student, stft_truth, q_0, z
            del x_student, mu_s, logs_s, mu_logs_t

    epoch_loss /= len(test_loader)
    print('Evaluation Loss : {:.4f}'.format(epoch_loss))
    del model_s_ema
    return epoch_loss
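# clone_as_averaged_model() is defined elsewhere in this repo. A minimal
# sketch, assuming the EMA interface above (a `shadow` dict keyed by
# parameter name); the name is suffixed to mark it as illustrative:
import copy


def clone_as_averaged_model_sketch(model, ema):
    """Deep-copy the model and load the EMA shadow weights into the copy,
    leaving the live training weights untouched."""
    averaged = copy.deepcopy(model)
    for name, param in averaged.named_parameters():
        if name in ema.shadow:
            param.data.copy_(ema.shadow[name])
    return averaged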
def evaluate(model, ema=None):
    # Evaluate the EMA-averaged weights when available (the original body
    # ran the forward pass on the live model and updated the EMA shadow
    # inside evaluation, both apparent copy-paste slips from train()).
    model_ema = clone_as_averaged_model(model, ema) if ema is not None else model
    model_ema.eval()
    running_loss = [0., 0., 0., 0., 0.]
    epoch_loss = 0.
    display_step = 100
    with torch.no_grad():
        for batch_idx, (x, _, c, _) in enumerate(test_loader):
            x, c = x.to(device), c.to(device)
            x_rec, x_prior, loss_rec, loss_kl = model_ema(x, c)

            stft_rec, stft_rec_log = stft(x_rec[:, 0, 1:])
            stft_truth, stft_truth_log = stft(x[:, 0, 1:])
            stft_prior, stft_prior_log = stft(x_prior[:, 0, 1:])

            loss_frame_rec = criterion_l2(stft_rec, stft_truth) + criterion_l1(
                stft_rec_log, stft_truth_log)
            loss_frame_prior = criterion_l2(stft_prior, stft_truth) + criterion_l1(
                stft_prior_log, stft_truth_log)

            # KL annealing coefficient: sigmoid ramp centered at step 1e6.
            alpha = 1 / (1 + np.exp(-5e-5 * (global_step - 1e+6)))
            loss_rec, loss_kl = loss_rec.mean(), loss_kl.mean()
            loss_tot = loss_rec + loss_kl * alpha + loss_frame_rec + loss_frame_prior

            running_loss[0] += loss_tot.item() / display_step
            running_loss[1] += loss_rec.item() / display_step
            running_loss[2] += loss_kl.item() / display_step
            running_loss[3] += loss_frame_rec.item() / display_step
            running_loss[4] += loss_frame_prior.item() / display_step
            epoch_loss += loss_tot.item()

            if (batch_idx + 1) % display_step == 0:
                print(
                    'Global Step : {}, [{}] [Total Loss, Rec Loss, KL Loss, STFT Recon, STFT Prior] : {}'
                    .format(global_step, batch_idx + 1,
                            np.array(running_loss)))
                running_loss = [0., 0., 0., 0., 0.]

            del loss_tot, loss_frame_rec, loss_frame_prior, loss_kl, loss_rec
            del x, c, x_rec, x_prior
            del stft_rec, stft_rec_log, stft_truth, stft_truth_log, stft_prior, stft_prior_log

    epoch_loss /= len(test_loader)
    print('Evaluation Loss : {:.4f}'.format(epoch_loss))
    del model_ema
    return epoch_loss
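# A hypothetical driver loop for the train()/evaluate() pair above; the names
# `num_epochs` and `save_checkpoint` are illustrative, not from this file:
#
#     best_loss = float('inf')
#     for epoch in range(1, num_epochs + 1):
#         train(epoch, model, optimizer, scheduler, ema)
#         test_loss = evaluate(model, ema)
#         if test_loss < best_loss:
#             best_loss = test_loss
#             save_checkpoint(model, optimizer, global_step)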
def train(epoch, model_t, model_s, optimizer, ema):
    global global_step
    epoch_loss = 0.0
    running_loss = [0.0, 0.0, 0.0, 0.0]
    model_t.eval()
    model_s.train()
    start_time = time.time()
    display_step = 100
    for batch_idx, (x, y, c, _) in enumerate(train_loader):
        global_step += 1

        # Halve the learning rate at 200k, 400k, and 600k steps.
        if global_step in (200000, 400000, 600000):
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
                state['learning_rate'] = param_group['lr']

        x, y, c = x.to(device), y.to(device), c.to(device)

        # Draw z ~ N(0, I) and let the student transform it into a waveform.
        q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size()))
        z = q_0.sample()

        optimizer.zero_grad()
        c_up = model_t.upsample(c)
        x_student, mu_s, logs_s = model_s(z, c_up)  # q_T ~ N(mu_tot, logs_tot.exp_())
        mu_logs_t = model_t(x_student, c)

        if args.KL_type == 'pq':
            loss_t, loss_KL, loss_reg = criterion_t(mu_logs_t[:, 0:1, :-1],
                                                    mu_logs_t[:, 1:, :-1],
                                                    mu_s, logs_s)
        elif args.KL_type == 'qp':
            loss_t, loss_KL, loss_reg = criterion_t(mu_s, logs_s,
                                                    mu_logs_t[:, 0:1, :-1],
                                                    mu_logs_t[:, 1:, :-1])
        else:
            raise ValueError('Unknown KL_type: {}'.format(args.KL_type))

        stft_student = stft(x_student[:, 0, 1:], scale='linear')
        stft_truth = stft(x[:, 0, 1:], scale='linear')

        loss_frame = criterion_frame(stft_student, stft_truth)
        loss_tot = loss_t + loss_frame

        loss_tot.backward()
        nn.utils.clip_grad_norm_(model_s.parameters(), 10.)
        optimizer.step()

        if ema is not None:
            for name, param in model_s.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

        running_loss[0] += loss_tot.item() / display_step
        running_loss[1] += loss_KL.item() / display_step
        running_loss[2] += loss_reg.item() / display_step
        running_loss[3] += loss_frame.item() / display_step
        epoch_loss += loss_tot.item()

        if (batch_idx + 1) % display_step == 0:
            end_time = time.time()
            print(
                'Global Step : {}, [{}, {}] [Total Loss, KL Loss, Reg Loss, Frame Loss] : {}'
                .format(global_step, epoch, batch_idx + 1,
                        np.array(running_loss)))
            print('{} Step Time : {}'.format(display_step,
                                             end_time - start_time))
            start_time = time.time()
            running_loss = [0.0, 0.0, 0.0, 0.0]

        del loss_tot, loss_frame, loss_KL, loss_reg, loss_t
        del x, y, c, c_up, stft_student, stft_truth, q_0, z
        del x_student, mu_s, logs_s, mu_logs_t

    print('{} Epoch Training Loss : {:.4f}'.format(
        epoch, epoch_loss / len(train_loader)))
    return epoch_loss / len(train_loader)
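# The stft() helper is defined elsewhere in this repo; it is called two ways
# above: stft(x) returning (linear magnitude, log magnitude), and
# stft(x, scale='linear') returning only the linear magnitude. A minimal
# sketch consistent with both call sites (n_fft, hop_length, and the log
# floor are assumptions):
def stft_sketch(x, n_fft=1024, hop_length=256, scale='both'):
    # x: (batch, time) waveform; magnitude spectrogram via torch.stft.
    window = torch.hann_window(n_fft, device=x.device)
    spec = torch.stft(x, n_fft, hop_length=hop_length, window=window,
                      return_complex=True)
    mag = spec.abs()
    if scale == 'linear':
        return mag
    return mag, torch.log(mag + 1e-7)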