예제 #1
0
def main(unused_argv):
    """Train the autoencoder, validating, logging and snapshotting as it goes.

    Creates ``FLAGS.output_dir`` if it does not exist, prepares a
    ``DataLoader``, builds an ``Autoencoder`` and optionally restores it from
    ``FLAGS.pretrained_fname``. Training steps then run until
    ``FLAGS.n_epochs`` epoch boundaries have been seen. Progress messages are
    passed to ``log_train`` together with a handle on ``text_log.txt`` inside
    ``FLAGS.log_dir``.

    Args:
        unused_argv: Ignored. Presumably supplied by the TensorFlow app
            runner -- confirm against the caller.

    Raises:
        AssertionError: If ``FLAGS.output_dir`` is empty.
    """
    # Local import so this block does not depend on the file-level imports.
    import os

    assert FLAGS.output_dir, "--output_dir is required"
    # Create training directory.
    output_dir = FLAGS.output_dir
    if not tf.gfile.IsDirectory(output_dir):
        tf.gfile.MakeDirs(output_dir)

    dl = DataLoader(FLAGS.db_fname, mean=cfgs.mean_value, scale=cfgs.scale, n_vals=FLAGS.n_vals)
    dl.prepare()

    x_dim = dl.get_data_dim()
    model = Autoencoder(x_dim, cfgs, log_dir=FLAGS.log_dir)
    model.quantize_weights()

    # History of validation losses, one entry per epoch boundary.
    val_losses = []

    # os.path.join keeps the log inside log_dir even when the flag has no
    # trailing separator (plain concatenation produced e.g. "logstext_log.txt").
    txt_log_fname = os.path.join(FLAGS.log_dir, 'text_log.txt')

    # The context manager closes the log even if training raises.
    with open(txt_log_fname, 'w') as log_fout:
        if FLAGS.pretrained_fname:
            log_train(log_fout, 'Resume from %s' %(FLAGS.pretrained_fname))
            try:
                model.restore(FLAGS.pretrained_fname)
            except Exception as err:
                # Best effort: a failed restore falls back to training from
                # the freshly built model, but the cause is recorded.
                log_train(log_fout, 'Cannot restore from %s' %(FLAGS.pretrained_fname))
                log_train(log_fout, 'Restore error: %s' %(err))

        lr = cfgs.initial_lr
        epoch_counter = 0
        ite = 0
        while True:
            start = time.time()
            # `flag` is treated as the epoch-boundary signal from the loader.
            x, flag = dl.next_batch(cfgs.batch_size, 'train')
            load_data_time = time.time() - start
            if flag:
                epoch_counter += 1

            do_log = (ite % FLAGS.log_every_n_steps == 0) or flag
            do_snapshot = flag and epoch_counter > 0 and epoch_counter % FLAGS.save_every_n_epochs == 0

            # train one step
            start = time.time()
            loss, _, summary, ite = model.partial_fit(x, lr, do_log)
            one_iter_time = time.time() - start

            # writing outs
            if do_log:
                log_train(log_fout, 'Iteration %d, (lr=%f) training loss  : %f' %(ite, lr, loss))
                if FLAGS.log_time:
                    log_train(log_fout, 'Iteration %d, data loading: %f(s) ; one iteration: %f(s)'
                        %(ite, load_data_time, one_iter_time))
                model.log(summary)
            if flag:
                val_loss = val(model, dl)
                val_losses.append(val_loss)
                log_train(log_fout, '----------------------------------------------------')
                # NOTE(review): `ite` was just reassigned by partial_fit, so
                # this branch only fires if partial_fit returns 0 -- confirm.
                if ite == 0:
                    log_train(log_fout, 'Initial validation loss: %f' %(val_loss))
                else:
                    log_train(log_fout, 'Epoch %d, validation loss: %f' %(epoch_counter, val_loss))
                log_train(log_fout, '----------------------------------------------------')
                # NOTE(review): do_log is always true when flag is set, so the
                # same summary was already logged above -- confirm intent.
                model.log(summary)
            if do_snapshot:
                log_train(log_fout, 'Snapshotting')
                model.save(FLAGS.output_dir)

            if flag:
                # 'val' schedule: decay when the newest validation loss is no
                # better than the worst of the four losses preceding it.
                if cfgs.lr_update == 'val' and len(val_losses) >= 5 and val_loss >= max(val_losses[-5:-1]):
                    lr = lr * cfgs.lr_decay_factor
                    log_train(log_fout, 'Decay learning rate to %f' %lr)
                # 'step' schedule: decay every num_epochs_per_decay epochs.
                elif cfgs.lr_update == 'step' and epoch_counter % cfgs.num_epochs_per_decay == 0:
                    lr = lr * cfgs.lr_decay_factor
                    log_train(log_fout, 'Decay learning rate to %f' %lr)
                if epoch_counter == FLAGS.n_epochs:
                    # Save the final model unless this epoch already did.
                    if not do_snapshot:
                        log_train(log_fout, 'Final snapshotting')
                        model.save(FLAGS.output_dir)
                    break
예제 #2
0
def main(unused_argv):
    """Train the DMFD model until ``FLAGS.n_epochs`` epochs have completed.

    Creates ``FLAGS.output_dir`` if it does not exist, loads and splits the
    data, builds the model and optionally restores it from
    ``FLAGS.pretrained_fname``. Training steps then run in a loop that prints
    progress, snapshots every ``FLAGS.save_every_n_epochs`` epochs, updates
    ``lambda_`` after every epoch and decays the learning rate every
    ``cfgs.num_epochs_per_decay`` epochs.

    Args:
        unused_argv: Ignored. Presumably supplied by the TensorFlow app
            runner -- confirm against the caller.

    Raises:
        AssertionError: If ``FLAGS.output_dir`` is empty.
    """
    assert FLAGS.output_dir, "--output_dir is required"
    # Create training directory.
    output_dir = FLAGS.output_dir
    if not tf.gfile.IsDirectory(output_dir):
        tf.gfile.MakeDirs(output_dir)

    dl = DataLoader(FLAGS.data_dir)
    dl.load_data()
    dl.split()

    x_dim = dl.get_X_dim()
    y_dim = dl.get_Y_dim()

    # Build the model.
    model = DMFD(x_dim,
                 y_dim,
                 dl.min_val,
                 dl.max_val,
                 cfgs,
                 log_dir=FLAGS.log_dir)

    if FLAGS.pretrained_fname:
        try:
            model.restore(FLAGS.pretrained_fname)
        except Exception as err:
            # Best effort: keep training from the freshly built model, but
            # report the failure instead of hiding it.
            print('Cannot restore from %s: %s' % (FLAGS.pretrained_fname, err))
        else:
            print('Resume from %s' % (FLAGS.pretrained_fname))

    lr = cfgs.initial_lr
    epoch_counter = 0
    ite = 0
    lambda_ = cfgs.base_lambda_
    while True:
        start = time.time()
        # `flag` is treated as the epoch-boundary signal from the loader.
        x, y, R, mask, flag = dl.next_batch(cfgs.batch_size_x,
                                            cfgs.batch_size_y, 'train')
        load_data_time = time.time() - start
        if flag:
            epoch_counter += 1

        # some boolean variables
        do_log = (ite % FLAGS.log_every_n_steps == 0)
        do_snapshot = flag and epoch_counter > 0 and epoch_counter % FLAGS.save_every_n_epochs == 0

        # train one step
        # Only request a summary when it will actually be written below.
        get_summary = do_log and cfgs.write_summary
        start = time.time()
        loss, _, summary, ite = model.partial_fit(x, y, R, mask, lr, lambda_,
                                                  get_summary)
        one_iter_time = time.time() - start

        # writing outs
        if do_log:
            print('Iteration %d, (lr=%f, lambda_=%f) training loss  : %f' %
                  (ite, lr, lambda_, loss))
            if FLAGS.log_time:
                print(
                    'Iteration %d, data loading: %f(s) ; one iteration: %f(s)'
                    % (ite, load_data_time, one_iter_time))
            if cfgs.write_summary:
                model.log(summary)

        if do_snapshot:
            print('Snapshotting')
            model.save(FLAGS.output_dir)

        if flag:
            # Update lambda_ using the zero-based index of the finished epoch.
            lambda_ = get_lambda_(lambda_, cfgs.anneal_rate, epoch_counter - 1,
                                  cfgs.sigmoid_schedule)
            print('Finished epoch %d' % epoch_counter)
            print('--------------------------------------')
            if epoch_counter == FLAGS.n_epochs:
                # Save the final model unless this epoch already did.
                if not do_snapshot:
                    print('Final snapshotting')
                    model.save(FLAGS.output_dir)
                break
            if epoch_counter % cfgs.num_epochs_per_decay == 0:
                lr = lr * cfgs.lr_decay_factor
                print('Decay learning rate to %f' % lr)