import os

import matplotlib.pyplot as plt
import numpy as np
import torch

# Project-local imports (module paths assumed; only the call sites appear in
# this snippet): the WaveGlow model and loss, plus helpers for tensor
# placement, checkpointing, evaluation, and sample generation.
from glow import WaveGlow, WaveGlowLoss
from utils import (generate_tests, get_test_loss_and_mse, get_validation_loss,
                   load_checkpoint, save_checkpoint, set_gpu_tensor,
                   set_gpu_train_tensor)


def training_procedure(dataset=None,
                       num_gpus=0,
                       output_directory='./train',
                       epochs=1000,
                       learning_rate=1e-4,
                       batch_size=12,
                       checkpointing=True,
                       checkpoint_path="./checkpoints",
                       seed=2019,
                       params=None,
                       use_gpu=True,
                       gen_tests=False,
                       mname='model',
                       validation_patience=10):
	# Default WaveGlow hyperparameters; a None default avoids the classic
	# mutable-default bug, where repeated calls would keep appending use_gpu
	# to the same shared list.
	if params is None:
		params = [96, 6, 24, 3, 8, 2, [1, 2], 96, 3]
	params = params + [use_gpu]
	torch.manual_seed(seed)
	if use_gpu:
		torch.cuda.manual_seed(seed)

	if checkpointing:
		os.makedirs(checkpoint_path, exist_ok=True)
	criterion = WaveGlowLoss()
	model = WaveGlow(*params)
	if use_gpu:
		model.cuda()

	valid_context, valid_forecast = dataset.valid_data()
	valid_forecast = set_gpu_tensor(valid_forecast, use_gpu)
	valid_context = set_gpu_tensor(valid_context, use_gpu)

	optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
	model.train()
	loss_iteration = []
	end_training = False
	best_validation = np.inf
	validation_streak = 0  # epochs since the validation loss last improved
	for epoch in range(epochs):
		if end_training: break
		iteration = 0
		print("Epoch: %d/%d" % (epoch+1, epochs))
		avg_loss = []
		# The dataset is assumed to clear epoch_end once a full pass has been
		# sampled; the flag is re-armed at the end of each epoch below.
		while dataset.epoch_end:
			context, forecast = dataset.sample(batch_size)
			forecast = set_gpu_train_tensor(forecast, use_gpu)
			context = set_gpu_train_tensor(context, use_gpu)
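			# Forward pass: the flow maps the forecast window, conditioned on
			# the context, to latent z, and returns the per-flow log-scales and
			# invertible-1x1-convolution log-determinants the loss term needs.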
			z, log_s_list, log_det_w_list, early_out_shapes = model(forecast, context)

			loss = criterion((z, log_s_list, log_det_w_list))
			reduced_loss = loss.item()
			loss_iteration.append(reduced_loss)
			optimizer.zero_grad()
			loss.backward()
			avg_loss.append(reduced_loss)
			optimizer.step()
			print("Epoch [%d/%d] on iteration %d with loss %.4f" % (epoch+1, epochs, iteration, reduced_loss))
			iteration += 1

		epoch_loss = sum(avg_loss)/len(avg_loss)
		validation_loss = get_validation_loss(model, criterion, valid_context, valid_forecast)
		print("Epoch [%d/%d] had training loss: %.4f and validation_loss: %.4f" % (epoch+1, epochs, epoch_loss, validation_loss))
		
		if best_validation > validation_loss:
			print("Validation loss improved to %.5f" % validation_loss)
			best_validation = validation_loss
			if gen_tests: generate_tests(dataset, model, 5, 96, use_gpu, str(epoch+1), mname=mname)
			if checkpointing:
				os.makedirs("%s/%s" % (output_directory, mname), exist_ok=True)
				best_checkpoint_path = "%s/%s/epoch-%d_loss-%.4f" % (output_directory, mname, epoch, validation_loss)
				save_checkpoint(model, optimizer, learning_rate, iteration, best_checkpoint_path, use_gpu)

			validation_streak = 0
		else:
			validation_streak += 1
		dataset.epoch_end = True  # re-arm the sampling flag for the next epoch

		# Early stopping: halt once validation has not improved for
		# validation_patience consecutive epochs.
		if validation_streak >= validation_patience: end_training = True

	if checkpointing:
		# Reload the best checkpoint (by validation loss) before testing.
		model, iteration = load_checkpoint(best_checkpoint_path, model)
		
	test_context, test_forecast = dataset.test_data()
	test_loss, test_mse = get_test_loss_and_mse(model, criterion, test_context, test_forecast, use_gpu)

	if not checkpointing:
		os.makedirs("%s/%s" % (output_directory, mname), exist_ok=True)
		final_path = "%s/%s/finalmodel_epoch-%d_testloss-%.4f_testmse-%.4f" % (output_directory, mname, epoch, test_loss, test_mse)
		save_checkpoint(model, optimizer, learning_rate, iteration, final_path, use_gpu)
	
	print("Test loss for this model is %.5f, mse loss: %.5f" % (test_loss, test_mse))

	plt.figure()
	plt.plot(range(len(loss_iteration)), np.log10(np.array(loss_iteration)+1.0))
	plt.xlabel('iteration')
	plt.ylabel('log10 of loss')
	plt.savefig('%s/%s/total_loss_graph.png' % (output_directory, mname))
	plt.close()
	return test_loss, model
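
The function above calls several project helpers that are not shown. Below is a minimal sketch of what they might look like, assuming checkpoints are plain torch.save dictionaries and the tensor helpers only handle device placement; every name and signature here is reconstructed from the call sites, not taken from the original project.

def set_gpu_tensor(x, use_gpu):
    # Wrap a NumPy array as a float32 tensor on the training device.
    t = torch.FloatTensor(x)
    return t.cuda() if use_gpu else t

def set_gpu_train_tensor(x, use_gpu):
    # Same placement for training batches; kept as a separate hook in case
    # training inputs ever need different handling (e.g. pinned memory).
    return set_gpu_tensor(x, use_gpu)

def save_checkpoint(model, optimizer, learning_rate, iteration, filepath, use_gpu):
    # Persist everything needed to resume training or evaluate the model.
    torch.save({'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'learning_rate': learning_rate,
                'iteration': iteration}, filepath)

def load_checkpoint(filepath, model):
    # Restore weights in place and report the iteration they were saved at.
    checkpoint = torch.load(filepath, map_location='cpu')
    model.load_state_dict(checkpoint['model_state'])
    return model, checkpoint['iteration']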
Example #2
def training(dataset=None,
             num_gpus=0,
             output_directory='./train',
             epochs=1000,
             learning_rate=1e-4,
             batch_size=12,
             checkpointing=True,
             checkpoint_path="./checkpoints",
             seed=2019,
             params=None,
             use_gpu=True,
             gen_tests=True):
    print("#############")
    print(use_gpu)
    params.append(use_gpu)
    torch.manual_seed(seed)
    if use_gpu:
        torch.cuda.manual_seed(seed)

    os.makedirs(output_directory, exist_ok=True)
    if checkpointing:
        os.makedirs(checkpoint_path, exist_ok=True)
    criterion = WaveGlowLoss()
    model = WaveGlow(*params)
    if use_gpu:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Optional resume support could reload model/optimizer/iteration state
    # here via load_checkpoint before training begins.

    model.train()
    loss_iteration = []
    for epoch in range(epochs):
        iteration = 0
        print("Epoch: %d/%d" % (epoch + 1, epochs))
        avg_loss = []
        # As above, the dataset is assumed to clear epoch_end after a full pass.
        while dataset.epoch_end:
            context, forecast = dataset.sample(batch_size)

            # torch.autograd.Variable is deprecated; plain tensors have carried
            # autograd state since PyTorch 0.4.
            forecast = torch.FloatTensor(forecast)
            context = torch.FloatTensor(context)
            if use_gpu:
                forecast = forecast.cuda()
                context = context.cuda()

            z, log_s_list, log_det_w_list, early_out_shapes = model(
                forecast, context)

            loss = criterion((z, log_s_list, log_det_w_list))
            reduced_loss = loss.item()
            loss_iteration.append(reduced_loss)
            optimizer.zero_grad()
            loss.backward()
            avg_loss.append(reduced_loss)
            optimizer.step()
            # print("On iteration %d with loss %.4f" % (iteration, reduced_loss))
            iteration += 1

        if gen_tests:
            generate_tests(dataset, model, 5, 96, use_gpu, str(epoch + 1))
        epoch_loss = sum(avg_loss) / len(avg_loss)
        if checkpointing:
            # Use a local name so the checkpoint_path parameter is not shadowed.
            ckpt_file = "%s/waveglow_epoch-%d_%.4f" % (output_directory, epoch,
                                                       epoch_loss)
            save_checkpoint(model, optimizer, learning_rate, iteration,
                            ckpt_file, use_gpu)

        print("\tLoss: %.3f" % loss)
        dataset.epoch_end = True
    plt.figure()
    plt.semilogy(range(len(loss_iteration)), np.array(loss_iteration))
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.savefig('total_loss_graph.png')
    plt.close()
    return model
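
Both trainers assume a dataset object with a small implicit interface: an epoch_end flag that the dataset clears once a full pass has been sampled (and that the trainer re-arms each epoch), a sample(batch_size) method returning (context, forecast) array batches, and valid_data() / test_data() accessors for held-out splits. The class below is a toy stand-in for that contract, with random numbers in place of a real time series; it is illustrative only, not the original data pipeline.

class ToyForecastDataset:
    def __init__(self, n_series=256, context_len=96, forecast_len=96):
        width = context_len + forecast_len
        self._train = np.random.randn(n_series, width).astype('float32')
        self._valid = np.random.randn(32, width).astype('float32')
        self._test = np.random.randn(32, width).astype('float32')
        self.context_len = context_len
        self.epoch_end = True  # cleared once a full training pass is sampled
        self._cursor = 0

    def _split(self, block):
        # First context_len steps condition the model; the rest is the target.
        return block[:, :self.context_len], block[:, self.context_len:]

    def sample(self, batch_size):
        batch = self._train[self._cursor:self._cursor + batch_size]
        self._cursor += batch_size
        if self._cursor >= len(self._train):
            self._cursor = 0
            self.epoch_end = False  # stops the trainer's inner while-loop
        return self._split(batch)

    def valid_data(self):
        return self._split(self._valid)

    def test_data(self):
        return self._split(self._test)

A smoke test, assuming the remaining helpers (get_validation_loss, get_test_loss_and_mse, generate_tests) are importable and the WaveGlow model accepts these shapes:

test_loss, model = training_procedure(dataset=ToyForecastDataset(), epochs=3,
                                      use_gpu=False, checkpointing=False)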