def train(self):
    """Run the full training loop over the speech dataset.

    Builds a fresh training dataset/loader, optimizes ``self.model`` with
    Adam + gradient clipping, logs the loss every PRINT_FREQ steps,
    validates every VALID_FREQ steps, and checkpoints every 100 total steps.
    Mutates ``self.tot_steps``; calls ``self.validate()`` and
    ``self.save_model()``.
    """
    self.model.train()
    dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                            FRAME_STRIDE, TEST_SIZE, self.device)
    dataset.init_dataset(test_mode=False)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    optimizer = torch.optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
    for epoch in range(MAX_EPOCHS):
        for i, data in enumerate(data_loader):
            x, y, cond = data
            pred_y = self.model(x, cond)
            loss = F.cross_entropy(pred_y, y)
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilize training before the step.
            clip_grad_norm_(self.model.parameters(), MAX_NORM)
            optimizer.step()
            if i % PRINT_FREQ == 0:
                self.logger.info(
                    'epoch: %d, step:%d, tot_step:%d, loss: %f' %
                    (epoch, i, self.tot_steps, loss.item()))
            if i % VALID_FREQ == 0:
                self.validate()
                # BUG FIX: validate() switches the model to eval mode and
                # the original code then called self.model.eval() again,
                # leaving dropout/batch-norm frozen for the remainder of
                # training. Restore train mode instead.
                self.model.train()
            self.tot_steps += 1
            if self.tot_steps % 100 == 0:
                self.save_model()
def validate(self):
    """Evaluate the model on the test split and log the mean loss.

    Runs at most MAX_VALID batches under ``torch.no_grad()`` and logs the
    average cross-entropy loss. Leaves the model in eval mode; the caller
    is responsible for switching back to train mode.
    """
    self.model.eval()
    dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                            FRAME_STRIDE, TEST_SIZE, self.device)
    dataset.init_dataset(test_mode=True)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    res = []
    # FIX: disable autograd during evaluation — the original built (and
    # kept) computation graphs for every validation batch, wasting memory.
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            if i == MAX_VALID:
                break
            x, y, cond = data
            pred_y = self.model(x, cond)
            loss = F.cross_entropy(pred_y.squeeze(), y.squeeze())
            res.append(loss.item())
    # FIX: guard against an empty result (MAX_VALID == 0 or an empty
    # loader) — the original raised ZeroDivisionError here.
    if res:
        self.logger.info('valid loss: ' + str(sum(res) / len(res)))
    else:
        self.logger.info('valid loss: no validation batches were run')
def generate(self):
    """Generate audio samples conditioned on test-set features.

    For up to MAX_GENERATE batches, autoregressively generates signals of
    length MAX_GENERATE_LENGTH from the conditioning input, dequantizes
    them, and writes each one to ./samples/sampleN.wav. Mutates
    ``self.sample_count``.
    """
    self.model.eval()
    dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                            FRAME_STRIDE, TEST_SIZE, self.device)
    dataset.init_dataset(test_mode=True)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    # FIX: generation needs no gradients — the original tracked autograd
    # state through the whole autoregressive rollout.
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            if i == MAX_GENERATE:
                break
            _, _, cond = data
            res = self.model.generate(cond, MAX_GENERATE_LENGTH)
            res = dequantize_signal(res, N_CLASS)
            for j in range(res.shape[0]):
                # NOTE(review): librosa.output.write_wav was removed in
                # librosa 0.8; if the pinned librosa is newer, switch to
                # soundfile.write — confirm the project's librosa version.
                librosa.output.write_wav(
                    './samples/sample%d.wav' % (self.sample_count),
                    res[j], SAMPLE_RATE)
                self.sample_count += 1