コード例 #1
0
		# One test item per line of the file `f` opened by the enclosing `with`
		# statement (that statement starts before this excerpt).
		num_test = len(f.readlines())

	# Enqueue jobs: one Task per training item, then one per test item. The
	# last Task argument is the item's index within its own split.
	# NOTE(review): train_list, test_list, num_train, tasks, Task,
	# DATA_SAVE_DIR, num_consumers, results_tr, CKPT_PATH and T are presumably
	# set up earlier in the enclosing function/module (outside this excerpt) —
	# confirm.
	for i in range(num_train):
		tasks.put(Task(train_list[i], DATA_SAVE_DIR, 'train', i))
	
	for i in range(num_test):
		tasks.put(Task(test_list[i], DATA_SAVE_DIR, 'test', i))

	# Add a poison pill for each consumer: one None per consumer, queued after
	# all real jobs, so every consumer can be told to stop (presumably each
	# consumer exits when it reads None — confirm in the consumer code).
	for i in range(num_consumers):
		tasks.put(None)
	
	# Build, compile and plot the model. The dilation rates double from 1 to
	# 512 (10 rates) with filter_width=2.
	# NOTE(review): input_dim=256+406+2 is a magic constant whose breakdown is
	# not documented here — confirm against the feature pipeline.
	wvn = WaveNet(input_dim=256+406+2, dilations=[1,2,4,8,16,32,64,128,256,512], filter_width=2)
	wvn.build()
	wvn.compile()
	wvn.plot()
	# Checkpoint filename pattern; '{epoch:02d}' is presumably filled in by a
	# Keras-style checkpoint callback — TODO confirm in WaveNet.add_callbacks.
	wvn.add_callbacks(os.path.join(CKPT_PATH,'weights.epoch001.{epoch:02d}.hdf5'), None)

	# Start 1st epoch training: take one processed training result per job
	# and fit the model on it, recording how long each fit takes.
	# NOTE(review): num_jobs is never decremented in the lines visible here.
	# The loop body is cut off at the end of this excerpt; confirm the rest
	# of it decrements num_jobs, otherwise this loop never terminates.
	num_jobs = num_train
	train_files = []
	train_times = []  # duration of each fit_on_file call, as measured by T()
	while num_jobs:
		# Note: this rebinds `f` (previously the file handle used above) to
		# the next result. Presumably results_tr is a queue and get() blocks
		# until a worker produces one — confirm.
		f = results_tr.get()
		train_files.append(f)
		# Model training
		start = T()
		wvn.fit_on_file(f)
		end = T()
		train_times.append(end-start)
コード例 #2
0
ファイル: train.py プロジェクト: jostosh/wavenet
def _num_steps(num_examples, batch_size):
    """Return the number of batches of ``batch_size`` needed to cover ``num_examples``.

    The ratio is rounded *up*, so a trailing partial batch still counts as a
    step (e.g. 10 examples with batch size 3 -> 4 steps). The float cast keeps
    the division exact even where ``/`` on two ints would truncate.
    """
    return int(np.ceil(float(num_examples) / batch_size))


def _update_running_mean(mean, value, count):
    """Fold ``value``, the ``count``-th sample (1-based), into the running ``mean``."""
    return mean + (value - mean) / count


def train(args):
    """Train a WaveNet on the SimpleWaveForms dataset, evaluating after each epoch.

    Builds the model and both input pipelines from ``args``, compiles one graph
    on the train iterator and one on the test iterator, then runs
    ``args.num_epochs`` epochs. During training, the merged scalar summaries are
    written every ``args.summary_interval`` steps. After each epoch, the mean
    test cross-entropy and MSE are written under the same tags to a separate
    test log directory.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            model hyper-parameters ``regularize_coeff``, ``lr``, ``global_cond``,
            ``dilation_stacks``, ``filter_width``, ``quantization_channels``,
            ``dilation_channels``, ``global_cond_depth``, ``residual_channels``,
            ``dilation_pow2``, ``skip_channels``; data settings
            ``sequence_len``, ``freq_min``, ``freq_max``, ``sample_freq``,
            ``train_size``, ``test_size``, ``batch_size``; loop settings
            ``num_epochs`` and ``summary_interval``.
    """
    wavenet = WaveNet(regularize_coeff=args.regularize_coeff,
                      learning_rate=args.lr,
                      global_condition=args.global_cond,
                      dilation_stacks=args.dilation_stacks,
                      filter_width=args.filter_width,
                      quantization_channels=args.quantization_channels,
                      dilation_channels=args.dilation_channels,
                      global_cond_depth=args.global_cond_depth,
                      residual_channels=args.residual_channels,
                      dilation_pow2=args.dilation_pow2,
                      skip_channels=args.skip_channels)

    # Load the data
    logger.info("Loading SimpleWaveForms data")
    dataset = SimpleWaveForms(sequence_len=args.sequence_len,
                              freq_range=(args.freq_min, args.freq_max),
                              sample_freq=args.sample_freq)

    with tf.name_scope("InputPipeline"):
        train_data, test_data = dataset.pipeline(
            args.train_size, args.test_size, batch_size=args.batch_size)
        train_t0, train_t1, train_cond = train_data
        test_t0, test_t1, test_cond = test_data

    # Compile graph with train data iterator. In 'train' mode compile returns
    # (predictions, loss, optimize_op, logits, mse); only loss, optimize_op
    # and mse are used here.
    _, train_loss, optimize, _, train_mse = wavenet.compile(
        train_t0, train_t1, train_cond, mode='train')

    # Compile graph with test data iterator. In 'test' mode compile returns
    # (predictions, loss, logits, mse).
    _, test_loss, _, test_mse = wavenet.compile(
        test_t0, test_t1, test_cond, mode='test')

    logger.info("Layer overview")
    for name in wavenet.layers:
        logger.info(name)

    # Setup summaries
    logger.info("Setting up summaries")
    with tf.name_scope("ScalarSummaries"):
        tf.summary.scalar("CrossEntropyLoss", train_loss)
        tf.summary.scalar("MeanSquaredError", train_mse)
        summary_op = tf.summary.merge_all()

    # Do training
    logger.info("Creating session")
    with tf.Session() as sess:

        # Tensorboard writers: the test summaries below reuse the train tags
        # ("ScalarSummaries/...") but go to their own log directory.
        logdir_train, logdir_test = next_logdir()
        logger.info("Logging train results at %s", logdir_train)
        logger.info("Logging test  results at %s", logdir_test)
        train_writer = tf.summary.FileWriter(logdir=logdir_train,
                                             graph=sess.graph)
        test_writer = tf.summary.FileWriter(logdir=logdir_test)
        try:
            # Initialize
            logger.info("Initializing variables")
            sess.run(tf.global_variables_initializer())

            # Steps per epoch; a trailing partial batch counts as one step.
            train_steps = _num_steps(args.train_size, args.batch_size)
            test_steps = _num_steps(args.test_size, args.batch_size)
            # Zero-pad epoch numbers to the digit count of num_epochs.
            epoch_width = len(str(args.num_epochs))
            for epoch in range(args.num_epochs):

                # Train it
                logger.info("Current epoch: %s", str(epoch).zfill(epoch_width))
                if epoch == 0:
                    logger.info(
                        "First epoch, graph initialization will take some time.")
                pbar = tqdm.trange(train_steps, desc="Train")
                loss_avg, mse_avg = 0.0, 0.0
                for i in pbar:
                    if i % args.summary_interval == 0:
                        # Also evaluate the merged summary op on this step.
                        loss, _, summary_out, mse = sess.run(
                            [train_loss, optimize, summary_op, train_mse])
                        train_writer.add_summary(
                            summary=summary_out,
                            global_step=epoch * train_steps + i)
                        train_writer.flush()
                    else:
                        loss, _, mse = sess.run(
                            [train_loss, optimize, train_mse])
                    loss_avg = _update_running_mean(loss_avg, loss, i + 1)
                    mse_avg = _update_running_mean(mse_avg, mse, i + 1)
                    pbar.set_postfix(Loss=loss_avg, MSE=mse_avg)

                # Test it
                logger.info("Testing")
                if epoch == 0:
                    logger.info(
                        "First epoch, graph initialization will take some time.")
                pbar = tqdm.trange(test_steps, desc="Test")
                loss_avg, mse_avg = 0.0, 0.0
                for i in pbar:
                    loss, mse = sess.run([test_loss, test_mse])
                    loss_avg = _update_running_mean(loss_avg, loss, i + 1)
                    mse_avg = _update_running_mean(mse_avg, mse, i + 1)
                    pbar.set_postfix(Loss=loss_avg, MSE=mse_avg)

                # Add a summary for average test loss and MSE, placed at the
                # global training step where this epoch's training ended.
                loss_summary = tf.Summary.Value(
                    tag="ScalarSummaries/CrossEntropyLoss",
                    simple_value=loss_avg)
                mse_summary = tf.Summary.Value(
                    tag="ScalarSummaries/MeanSquaredError",
                    simple_value=mse_avg)
                test_writer.add_summary(
                    summary=tf.Summary(value=[loss_summary, mse_summary]),
                    global_step=(epoch + 1) * train_steps)
                test_writer.flush()
        finally:
            # close() flushes and closes the event files; without it they
            # remain open after training finishes or fails.
            train_writer.close()
            test_writer.close()