Example #1
def test(args):
    """
    Testing (evaluation of the trained model)
    """

    ##  ~~~~~~~~~~~~~~~~~~~
    ##   Initial settings
    ##  ~~~~~~~~~~~~~~~~~~~

    #   Input Variable
    nn.clear_parameters()  # Clear
    Input = nn.Variable([1, 3, 64, 64])  # Input
    Trues = nn.Variable([1, 1])  # True Value

    #   Network Definition
    Name = "CNN"  # Name of scope which includes network models (arbitrary)
    Output_test = network(Input, scope=Name, test=True)  # Network & Output
    Loss_test = F.mean(F.absolute_error(
        Output_test, Trues))  # Loss Function (Mean Absolute Error)

    #   Load trained parameters
    with nn.parameter_scope(Name):
        nn.load_parameters(
            os.path.join(args.model_save_path,
                         "network_param_{:04}.h5".format(args.epoch)))

    # Test Data Setting
    image_data, mos_data = dt.data_loader(test=True)
    batches = dt.create_batch(image_data, mos_data, 1)  # batch size 1: one image per forward pass
    del image_data, mos_data

    truth = []
    result = []
    for j in range(batches.iter_n):
        Input.d, Trues.d = next(batches)  # feed the image and its true MOS score
        Output_test.forward()  # predict the score with the trained network
        result.append(Output_test.d.copy())
        truth.append(Trues.d.copy())

    result = np.squeeze(np.array(result))
    truth = np.squeeze(np.array(truth))

    # Evaluation of performance
    mae = np.average(np.abs(result - truth))
    SRCC, p1 = stats.spearmanr(truth,
                               result)  # Spearman's Correlation Coefficient
    PLCC, p2 = stats.pearsonr(truth, result)
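    # SRCC measures rank (monotonic) agreement and PLCC measures linear agreement
    # between the predicted and ground-truth scores; both lie in [-1, 1].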

    #   Display
    print("\n Model Parameter [epoch={0}]".format(args.epoch))
    print(" Mean Absolute Error with Truth: {0:.4f}".format(mae))
    print(" Speerman's Correlation Coefficient: {0:.3f}".format(SRCC))
    print(" Pearson's Linear Correlation Coefficient: {0:.3f}".format(PLCC))
Example #2
def train(args):
    """
    Training
    """

    ##  ~~~~~~~~~~~~~~~~~~~
    ##   Initial settings
    ##  ~~~~~~~~~~~~~~~~~~~

    #   Input Variable
    nn.clear_parameters()  #   Clear
    Input = nn.Variable([args.batch_size, 3, 64, 64])  #   Input
    Trues = nn.Variable([args.batch_size, 1])  #   True Value

    #   Network Definition
    Name = "CNN"  #   Name of scope which includes network models (arbitrary)
    Output = network(Input, scope=Name)  #   Network & Output
    Output_test = network(Input, scope=Name, test=True)

    #   Loss Definition
    Loss = F.mean(F.absolute_error(Output,
                                   Trues))  #   Loss Function (Mean Absolute Error)
    Loss_test = F.mean(F.absolute_error(Output_test, Trues))

    #   Solver Setting
    solver = S.AMSBound(args.learning_rate)  #   AMSBound (an Adam variant) is used as the solver
    with nn.parameter_scope(
            Name):  #   Get updating parameters included in scope
        solver.set_parameters(nn.get_parameters())

    #   Training Data Setting
    image_data, mos_data = dt.data_loader()
    batches = dt.create_batch(image_data, mos_data, args.batch_size)
    del image_data, mos_data

    #   Test Data Setting
    image_data, mos_data = dt.data_loader(test=True)
    batches_test = dt.create_batch(image_data, mos_data, args.batch_size)
    del image_data, mos_data

    ##  ~~~~~~~~~~~~~~~~~~~
    ##   Learning
    ##  ~~~~~~~~~~~~~~~~~~~
    print('== Start Training ==')

    bar = tqdm(total=args.epoch - args.retrain, leave=False)
    bar.clear()
    loss_disp = None
    SRCC = None

    #   Load pre-trained parameters when retraining
    if args.retrain > 0:
        with nn.parameter_scope(Name):
            print('Retrain from {0} Epoch'.format(args.retrain))
            nn.load_parameters(
                os.path.join(args.model_save_path,
                             "network_param_{:04}.h5".format(args.retrain)))
            solver.set_learning_rate(args.learning_rate /
                                     np.sqrt(args.retrain))

    ##  Training
    for i in range(args.retrain, args.epoch):

        bar.set_description_str('Epoch {0}:'.format(i + 1), refresh=False)
        if (loss_disp is not None) and (SRCC is not None):
            bar.set_postfix_str('Loss={0:.5f},  SRCC={1:.4f}'.format(
                loss_disp, SRCC),
                                refresh=False)
        bar.update(1)

        #   Shuffling
        batches.shuffle()
        batches_test.shuffle()

        ##  Batch iteration
        for j in range(batches.iter_n):

            #  Load Batch Data from Training data
            Input.d, Trues.d = next(batches)

            #  Update
            solver.zero_grad()  #   Initialize
            Loss.forward(clear_no_need_grad=True)  #   Forward path
            Loss.backward(clear_buffer=True)  #   Backward path
            solver.weight_decay(0.00001)  #   Weight Decay for stable update
            solver.update()

        ## Progress
        # Get result for Display
        Input.d, Trues.d = next(batches_test)
        Loss_test.forward(clear_no_need_grad=True)
        Output_test.forward()
        loss_disp = Loss_test.d
        SRCC, _ = stats.spearmanr(Output_test.d, Trues.d)

        # Display text
        # disp(i, batches.iter_n, Loss_test.d)

        ## Save parameters
        if ((i + 1) % args.model_save_cycle) == 0 or (i + 1) == args.epoch:
            bar.clear()
            with nn.parameter_scope(Name):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'network_param_{:04}.h5'.format(i + 1)))
Example #3
def main(args):

    # get datasets
    dataset = data.get_dataset(args.dataset,
                               args.split,
                               image_size=args.image_size,
                               data_dir=args.data_dir,
                               is_training=True)

    im_x = preprocess(dataset.x,
                      args.preprocessing_a,
                      image_size=args.image_size,
                      output_channels=args.num_channels)
    im_y = preprocess(dataset.y,
                      args.preprocessing_b,
                      image_size=args.image_size)

    im_batch_x, im_batch_y = data.create_batch([im_x, im_y],
                                               batch_size=args.batch_size,
                                               shuffle=args.shuffle,
                                               queue_size=2,
                                               min_queue_size=1)

    # build models

    transformed_x = model.transformer(im_batch_x,
                                      output_channels=dataset.num_classes,
                                      output_fn=None,
                                      scope='model/AtoB')
    transformed_y = model.transformer(im_batch_y,
                                      output_channels=args.num_channels,
                                      scope='model/BtoA')

    cycled_x = model.transformer(tf.nn.softmax(transformed_x),
                                 output_channels=args.num_channels,
                                 scope='model/BtoA',
                                 reuse=True)
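    # tf.nn.softmax above converts the AtoB segmentation logits (output_fn=None)
    # into class probabilities before they are cycled back through BtoA.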
    cycled_y = model.transformer(transformed_y,
                                 output_channels=dataset.num_classes,
                                 output_fn=None,
                                 scope='model/AtoB',
                                 reuse=True)

    # create loss functions

    cycle_loss_x = tf.losses.absolute_difference(im_batch_x,
                                                 cycled_x,
                                                 scope='cycle_loss_x')
    cycle_loss_y = tf.losses.softmax_cross_entropy(im_batch_y,
                                                   cycled_y,
                                                   scope='cycle_loss_y')

    transform_loss_xy = tf.losses.absolute_difference(
        im_batch_x, transformed_y, scope='transform_loss_xy')
    transform_loss_yx = tf.losses.softmax_cross_entropy(
        im_batch_y, transformed_x, scope='transform_loss_yx')

    total_loss = cycle_loss_x + cycle_loss_y + transform_loss_xy + transform_loss_yx

    optimizer = tf.train.AdamOptimizer(args.learning_rate, args.beta1,
                                       args.beta2, args.epsilon)

    inc_global_step = tf.assign_add(tf.train.get_or_create_global_step(), 1)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, inc_global_step)
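    # Adding the global-step increment to UPDATE_OPS makes it run as a dependency
    # of the train op created below, so the step advances once per training step.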

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_tensor = optimizer.minimize(total_loss)

        # Set up train op to return loss
        with tf.control_dependencies([train_tensor]):
            train_op = tf.identity(total_loss, name='train_op')

    # set up logging

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES):
        summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        summaries.add(tf.summary.histogram(variable.op.name, variable))

    color_map = np.array(
        list(map(lambda x: x.color,
                 labels[:dataset.num_classes]))).astype(np.float32)

    segmentation_y = postprocess(tf.argmax(im_batch_y,
                                           -1), 'segmentation_to_rgb',
                                 dataset.num_classes, color_map)
    segmentation_transformed_x = postprocess(tf.argmax(transformed_x, -1),
                                             'segmentation_to_rgb',
                                             dataset.num_classes, color_map)
    segmentation_cycled_y = postprocess(tf.argmax(cycled_y,
                                                  -1), 'segmentation_to_rgb',
                                        dataset.num_classes, color_map)

    summaries.add(tf.summary.image('x', im_batch_x))
    summaries.add(tf.summary.image('y', segmentation_y))
    summaries.add(tf.summary.image('transformed_x',
                                   segmentation_transformed_x))
    summaries.add(tf.summary.image('transformed_y', transformed_y))
    summaries.add(tf.summary.image('cycled_x', cycled_x))
    summaries.add(tf.summary.image('cycled_y', segmentation_cycled_y))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # create train loop

    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    saver = tf.train.Saver(var_list=tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope='model'))
    checkpoint_path = os.path.join(args.output_dir, 'model.ckpt')
    writer = tf.summary.FileWriter(args.output_dir)

    with tf.Session() as sess:
        # Tensorflow initializations
        sess.run(tf.get_collection(tf.GraphKeys.TABLE_INITIALIZERS))
        tf.train.start_queue_runners(sess=sess)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        last_log_time = 0
        last_save_time = 0
        for i in tqdm(range(args.num_batches)):
            if last_log_time < time.time() - args.log_every_n_seconds:
                last_log_time = time.time()
                summary, loss_val, global_step = sess.run(
                    [summary_op, train_op,
                     tf.train.get_global_step()])
                writer.add_summary(summary, global_step)
                writer.flush()
            else:
                loss_val, global_step = sess.run(
                    [train_op, tf.train.get_global_step()])

            if last_save_time < time.time() - args.save_every_n_seconds:
                last_save_time = time.time()
                saver.save(sess, checkpoint_path, global_step=global_step)

        saver.save(sess, checkpoint_path, global_step=args.num_batches)
Example #4
def main(args):

    # get datasets
    source_dataset = data.get_dataset(args.source, args.split)
    target_dataset = data.get_dataset(args.target, args.split)

    im_s = preprocess(source_dataset.x,
                      args.preprocessing,
                      image_size=args.image_size,
                      output_channels=args.output_channels)
    label_s = source_dataset.y

    im_t = preprocess(target_dataset.x,
                      args.preprocessing,
                      image_size=args.image_size,
                      output_channels=args.output_channels)
    label_t = target_dataset.y

    im_batch_s, label_batch_s, im_batch_t, label_batch_t = data.create_batch(
        [im_s, label_s, im_t, label_t],
        batch_size=args.batch_size,
        shuffle=args.shuffle)

    # build models

    transformed_s = model.transformer(im_batch_s, scope='model/s_to_t')
    transformed_t = model.transformer(im_batch_t, scope='model/t_to_s')

    cycled_s = model.transformer(transformed_s,
                                 scope='model/t_to_s',
                                 reuse=True)
    cycled_t = model.transformer(transformed_t,
                                 scope='model/s_to_t',
                                 reuse=True)

    # create loss functions

    cycle_loss_s = tf.losses.absolute_difference(im_batch_s,
                                                 cycled_s,
                                                 scope='cycle_loss_s')
    cycle_loss_t = tf.losses.absolute_difference(im_batch_t,
                                                 cycled_t,
                                                 scope='cycle_loss_t')

    total_loss = cycle_loss_s + cycle_loss_t

    optimizer = tf.train.AdamOptimizer(args.learning_rate, args.beta1,
                                       args.beta2, args.epsilon)

    inc_global_step = tf.assign_add(tf.train.get_or_create_global_step(), 1)
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, inc_global_step)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_tensor = optimizer.minimize(total_loss)

        # Set up train op to return loss
        with tf.control_dependencies([train_tensor]):
            train_op = tf.identity(total_loss, name='train_op')

    # set up logging

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES):
        summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        summaries.add(tf.summary.histogram(variable.op.name, variable))

    summaries.add(tf.summary.image('source', im_batch_s))
    summaries.add(tf.summary.image('target', im_batch_t))
    summaries.add(tf.summary.image('source_transformed', transformed_s))
    summaries.add(tf.summary.image('target_transformed', transformed_t))
    summaries.add(tf.summary.image('source_cycled', cycled_s))
    summaries.add(tf.summary.image('target_cycled', cycled_t))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # create train loop

    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    saver = tf.train.Saver(var_list=tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope='model'))
    checkpoint_path = os.path.join(args.output_dir, 'model.ckpt')
    writer = tf.summary.FileWriter(args.output_dir)

    with tf.Session() as sess:
        # Tensorflow initializations
        sess.run(tf.get_collection(tf.GraphKeys.TABLE_INITIALIZERS))
        tf.train.start_queue_runners(sess=sess)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        last_log_time = 0
        last_save_time = 0
        for i in tqdm(range(args.num_batches)):
            if last_log_time < time.time() - args.log_every_n_seconds:
                last_log_time = time.time()
                summary, loss_val, global_step = sess.run(
                    [summary_op, train_op,
                     tf.train.get_global_step()])
                writer.add_summary(summary, global_step)
                writer.flush()
            else:
                loss_val, global_step = sess.run(
                    [train_op, tf.train.get_global_step()])

            if last_save_time < time.time() - args.save_every_n_seconds:
                last_save_time = time.time()
                saver.save(sess, checkpoint_path, global_step=global_step)

        saver.save(sess, checkpoint_path, global_step=args.num_batches)
Example #5
def train(args):

    ##  Sub-functions
    ## ---------------------------------
    ## Save Models
    def save_models(epoch_num, losses):

        # save generator parameter
        with nn.parameter_scope('Wave-U-Net'):
            nn.save_parameters(
                os.path.join(args.model_save_path,
                             'param_{:04}.h5'.format(epoch_num + 1)))

        # save results
        np.save(
            os.path.join(args.model_save_path,
                         'losses_{:04}.npy'.format(epoch_num + 1)),
            np.array(losses))

    ## Load Models
    def load_models(epoch_num, gen=True, dis=True):

        # load generator parameter
        with nn.parameter_scope('Wave-U-Net'):
            nn.load_parameters(
                os.path.join(args.model_save_path,
                             'param_{:04}.h5'.format(args.epoch_from)))

    ## Update parameters
    class updating:
        def __init__(self):
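            # Loss-scaling factor used when training in half precision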
            self.scale = 8 if args.halfprec else 1

        def __call__(self, solver, loss):
            solver.zero_grad()  # initialize
            loss.forward(clear_no_need_grad=True)  # calculate forward
            loss.backward(self.scale, clear_buffer=True)  # calculate backward
            #solver.scale_grad(1. / self.scale)                # scaling
            solver.update()  # update

    ##  Initial Settings
    ## ---------------------------------

    ##  Create network
    #   Clear
    nn.clear_parameters()
    #   Variables
    noisy = nn.Variable([args.batch_size, 1, 16384], need_grad=False)  # Input
    clean = nn.Variable([args.batch_size, 1, 16384], need_grad=False)  # Desire

    # Build Network
    # K=2, C=1
    target_1, target_2 = Wave_U_Net(noisy)
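    # target_1 estimates the clean speech and target_2 estimates the residual
    # noise (noisy - clean), as reflected in the loss below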

    # Mean Squared Error
    loss = (F.mean(F.squared_error(clean, target_1)) +
            F.mean(F.squared_error(noisy - clean, target_2))) / 2.

    # Optimizer: Adam
    solver = S.Adam(args.learning_rate)

    # set parameter
    with nn.parameter_scope('Wave-U-Net'):
        solver.set_parameters(nn.get_parameters())

    ##  Load data & Create batch
    clean_data, noisy_data = dt.data_loader()
    batches = dt.create_batch(clean_data, noisy_data, args.batch_size)
    del clean_data, noisy_data

    ##  Initial settings for sub-functions
    fig = figout()
    disp = display(args.epoch_from, args.epoch, batches.batch_num)
    upd = updating()

    ##  Train
    ##----------------------------------------------------

    print('== Start Training ==')

    ##  Load "Pre-trained" parameters
    if args.epoch_from > 0:
        print(' Retrain parameter from pre-trained network')
        load_models(args.epoch_from)
        losses = np.load(
            os.path.join(args.model_save_path,
                         'losses_{:04}.npy'.format(args.epoch_from)))

        ## Create loss loggers
        point = args.epoch_from * ((batches.batch_num + 1) // 10)
        loss_len = (args.epoch - args.epoch_from) * (
            (batches.batch_num + 1) // 10)
        losses = np.append(losses, np.zeros(loss_len))
    else:
        losses = []
        ## Create loss loggers
        point = len(losses)
        loss_len = (args.epoch - args.epoch_from) * (
            (batches.batch_num + 1) // 10)
        losses = np.append(losses, np.zeros(loss_len))

    ##  Training
    for i in range(args.epoch_from, args.epoch):

        print('')
        print(' =========================================================')
        print('  Epoch :: {0}/{1}'.format(i + 1, args.epoch))
        print(' =========================================================')
        print('')

        batches.shuffle()

        #  Batch iteration
        for j in range(batches.batch_num):
            print('  Train (Epoch. {0}) - {1}/{2}'.format(
                i + 1, j + 1, batches.batch_num))

            ##  Batch setting
            clean.d, noisy.d = batches.next(j)

            ##  Updating
            upd(solver, loss)  # update Generator

            ##  Display
            if j % 100 == 0:
                # Get result for Display
                target_1.forward(clear_no_need_grad=True)
                target_2.forward(clear_no_need_grad=True)

                # Display text
                disp(i, j, loss.d)

                # Data logger
                losses[point] = loss.d
                point = point + 1

                # Plot
                fig.waveform_1(noisy.d[0, 0, :], target_1.d[0, 0, :],
                               clean.d[0, 0, :])
                fig.waveform_2(noisy.d[0, 0, :], target_2.d[0, 0, :],
                               clean.d[0, 0, :])
                fig.loss(losses[0:point - 1])
                pg.QtGui.QApplication.processEvents()

        ## Save parameters
        if ((i + 1) % args.model_save_cycle) == 0:
            save_models(i, losses)  # save model
            # fig.save(os.path.join(args.model_save_path, 'plot_{:04}.pdf'.format(i + 1))) # save fig
            exporter = pg.exporters.ImageExporter(fig.win.scene(
            ))  # Call pg.QtGui.QApplication.processEvents() right before using exporters!
            exporter.export(
                os.path.join(args.model_save_path,
                             'plot_{:04}.png'.format(i + 1)))  # save fig

    ## Save parameters (Last)
    save_models(args.epoch - 1, losses)
    exporter = pg.exporters.ImageExporter(fig.win.scene(
    ))  # Call pg.QtGui.QApplication.processEvents() right before using exporters!
    exporter.export(
        os.path.join(args.model_save_path,
                     'plot_{:04}.png'.format(i + 1)))  # save fig
Example #6
def train(args):

    ##  Sub-functions
    ## ---------------------------------
    ## Save Models
    def save_models(epoch_num, cle_disout, fake_disout, losses_gen, losses_dis, losses_ae):

        # save generator parameter
        with nn.parameter_scope("gen"):
            nn.save_parameters(os.path.join(args.model_save_path, 'generator_param_{:04}.h5'.format(epoch_num + 1)))

        # save discriminator parameter
        with nn.parameter_scope("dis"):
            nn.save_parameters(os.path.join(args.model_save_path, 'discriminator_param_{:04}.h5'.format(epoch_num + 1)))

        # save results
        np.save(os.path.join(args.model_save_path, 'disout_his_{:04}.npy'.format(epoch_num + 1)), np.array([cle_disout, fake_disout]))
        np.save(os.path.join(args.model_save_path, 'losses_gen_{:04}.npy'.format(epoch_num + 1)), np.array(losses_gen))
        np.save(os.path.join(args.model_save_path, 'losses_dis_{:04}.npy'.format(epoch_num + 1)), np.array(losses_dis))
        np.save(os.path.join(args.model_save_path, 'losses_ae_{:04}.npy'.format(epoch_num + 1)), np.array(losses_ae))

    ## Load Models
    def load_models(epoch_num, gen=True, dis=True):

        # load generator parameter
        with nn.parameter_scope("gen"):
            nn.load_parameters(os.path.join(args.model_save_path, 'generator_param_{:04}.h5'.format(args.epoch_from)))

        # load discriminator parameter
        with nn.parameter_scope("dis"):
            nn.load_parameters(os.path.join(args.model_save_path, 'discriminator_param_{:04}.h5'.format(args.epoch_from)))

    ## Update parameters
    class updating:

        def __init__(self):
            self.scale = 8 if args.halfprec else 1

        def __call__(self, solver, loss):
            solver.zero_grad()                                  # initialize
            loss.forward(clear_no_need_grad=True)               # calculate forward
            loss.backward(self.scale, clear_buffer=True)      # calculate backward
            solver.scale_grad(1. / self.scale)                # scaling
            solver.weight_decay(args.weight_decay * self.scale) # decay
            solver.update()                                     # update


    ##  Initial Settings
    ## ---------------------------------

    ##  Create network
    #   Clear
    nn.clear_parameters()
    #   Variables
    noisy 		= nn.Variable([args.batch_size, 1, 16384], need_grad=False)  # Input
    clean 		= nn.Variable([args.batch_size, 1, 16384], need_grad=False)  # Desire
    z           = nn.Variable([args.batch_size, 1024, 8], need_grad=False)   # Random Latent Variable
    #   Generator
    genout = Generator(noisy, z)                       # Predicted Clean
    genout.persistent = True                # keep the buffer (not cleared at backward)
    loss_gen 	= Loss_gen(genout, clean, Discriminator(noisy, genout))
    loss_ae     = F.mean(F.absolute_error(genout, clean))
    #   Discriminator
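    #   Unlink the generator output so discriminator updates do not backpropagate into the generator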
    fake_dis 	= genout.get_unlinked_variable(need_grad=True)
    cle_disout  = Discriminator(noisy, clean)
    fake_disout  = Discriminator(noisy, fake_dis)
    loss_dis    = Loss_dis(Discriminator(noisy, clean), Discriminator(noisy, fake_dis))

    ##  Solver
    # RMSprop.
    # solver_gen = S.RMSprop(args.learning_rate_gen)
    # solver_dis = S.RMSprop(args.learning_rate_dis)
    # Adam
    solver_gen = S.Adam(args.learning_rate_gen)
    solver_dis = S.Adam(args.learning_rate_dis)
    # set parameter
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    ##  Load data & Create batch
    clean_data, noisy_data = dt.data_loader()
    batches     = dt.create_batch(clean_data, noisy_data, args.batch_size)
    del clean_data, noisy_data

    ##  Initial settings for sub-functions
    fig     = figout()
    disp    = display(args.epoch_from, args.epoch, batches.batch_num)
    upd     = updating()

    ##  Train
    ##----------------------------------------------------

    print('== Start Training ==')

    ##  Load "Pre-trained" parameters
    if args.epoch_from > 0:
        print(' Retrain parameter from pre-trained network')
        load_models(args.epoch_from, dis=False)
        losses_gen  = np.load(os.path.join(args.model_save_path, 'losses_gen_{:04}.npy'.format(args.epoch_from)))
        losses_dis  = np.load(os.path.join(args.model_save_path, 'losses_dis_{:04}.npy'.format(args.epoch_from)))
        losses_ae   = np.load(os.path.join(args.model_save_path, 'losses_ae_{:04}.npy'.format(args.epoch_from)))
    else:
        losses_gen  = []
        losses_ae   = []
        losses_dis  = []

    ## Create loss loggers
    point       = len(losses_gen)
    loss_len    = (args.epoch - args.epoch_from) * ((batches.batch_num+1)//10)
    losses_gen  = np.append(losses_gen, np.zeros(loss_len))
    losses_ae   = np.append(losses_ae, np.zeros(loss_len))
    losses_dis  = np.append(losses_dis, np.zeros(loss_len))

    ##  Training
    for i in range(args.epoch_from, args.epoch):

        print('')
        print(' =========================================================')
        print('  Epoch :: {0}/{1}'.format(i + 1, args.epoch))
        print(' =========================================================')
        print('')

        #  Batch iteration
        for j in range(batches.batch_num):
            print('  Train (Epoch. {0}) - {1}/{2}'.format(i+1, j+1, batches.batch_num))

            ##  Batch setting
            clean.d, noisy.d = batches.next(j)
            #z.d = np.random.randn(*z.shape)
            z.d = np.zeros(z.shape)
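            # The latent code is fixed to zeros here; uncomment the line above to sample random noise instead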

            ##  Updating
            upd(solver_gen, loss_gen)       # update Generator
            upd(solver_dis, loss_dis)       # update Discriminator

            ##  Display
            if (j+1) % 10 == 0:
                # Get result for Display
                cle_disout.forward()
                fake_disout.forward()
                loss_ae.forward(clear_no_need_grad=True)

                # Display text
                disp(i, j, loss_gen.d, loss_dis.d, loss_ae.d)

                # Data logger
                losses_gen[point] = loss_gen.d
                losses_ae[point]  = loss_ae.d
                losses_dis[point] = loss_dis.d
                point = point + 1

                # Plot
                fig.waveform(noisy.d[0,0,:], genout.d[0,0,:], clean.d[0,0,:])
                fig.loss(losses_gen[0:point-1], losses_ae[0:point-1], losses_dis[0:point-1])
                fig.histogram(cle_disout.d, fake_disout.d)
                pg.QtGui.QApplication.processEvents()


        ## Save parameters
        if ((i+1) % args.model_save_cycle) == 0:
            save_models(i, cle_disout.d, fake_disout.d, losses_gen[0:point-1], losses_dis[0:point-1], losses_ae[0:point-1])  # save model
            exporter = pg.exporters.ImageExporter(fig.win.scene())  # Call pg.QtGui.QApplication.processEvents() before exporters!!
            exporter.export(os.path.join(args.model_save_path, 'plot_{:04}.png'.format(i + 1))) # save fig

    ## Save parameters (Last)
    save_models(args.epoch-1, cle_disout.d, fake_disout.d, losses_gen, losses_dis, losses_ae)