Example #1
def train(M, src=None, trg=None, has_disc=True, saver=None, model_name=None):
    """Main training function

    Creates log file, manages datasets, trains model

    M          - (TensorDict) the model
    src        - (obj) source domain. Contains train/test Data obj
    trg        - (obj) target domain. Contains train/test Data obj
    has_disc   - (bool) whether model requires a discriminator update
    saver      - (Saver) saves models during training
    model_name - (str) name of the model being run with relevant params info
    """
    # Training settings
    bs = 64
    iterep = 1000
    itersave = 20000
    n_epoch = 80
    epoch = 0
    feed_dict = {}

    # Create a log directory and FileWriter
    log_dir = os.path.join(args.logdir, model_name)
    delete_existing(log_dir)
    train_writer = tf.summary.FileWriter(log_dir)

    # Create a save directory
    if saver:
        model_dir = os.path.join('checkpoints', model_name)
        delete_existing(model_dir)
        os.makedirs(model_dir)

    # Replace src domain with pseudolabeled trg
    if args.dirt > 0:
        print("Setting backup and updating backup model")
        src = PseudoData(args.trg, trg, M.teacher)
        M.sess.run(M.update_teacher)

        # Sanity check model
        print_list = []
        if src:
            save_value(M.fn_ema_acc,
                       'test/src_test_ema_1k',
                       src.test,
                       train_writer,
                       0,
                       print_list,
                       full=False)

        if trg:
            save_value(M.fn_ema_acc, 'test/trg_test_ema', trg.test,
                       train_writer, 0, print_list)
            save_value(M.fn_ema_acc,
                       'test/trg_train_ema_1k',
                       trg.train,
                       train_writer,
                       0,
                       print_list,
                       full=False)

        print(print_list)

    if src: get_info(args.src, src)
    if trg: get_info(args.trg, trg)
    print("Batch size:", bs)
    print("Iterep:", iterep)
    print("Total iterations:", n_epoch * iterep)
    print("Log directory:", log_dir)

    for i in range(n_epoch * iterep):
        # Run discriminator optimizer
        if has_disc:
            update_dict(M, feed_dict, src, trg, bs)
            summary, _ = M.sess.run(M.ops_disc, feed_dict)
            train_writer.add_summary(summary, i + 1)

        # Run main optimizer
        update_dict(M, feed_dict, src, trg, bs)
        summary, _ = M.sess.run(M.ops_main, feed_dict)
        train_writer.add_summary(summary, i + 1)
        train_writer.flush()

        end_epoch, epoch = tb.utils.progbar(i,
                                            iterep,
                                            message='{}/{}'.format(epoch, i),
                                            display=args.run >= 999)

        # Update pseudolabeler
        if args.dirt and (i + 1) % args.dirt == 0:
            print "Updating teacher model"
            M.sess.run(M.update_teacher)

        # Log end-of-epoch values
        if end_epoch:
            print_list = M.sess.run(M.ops_print, feed_dict)

            if src:
                save_value(M.fn_ema_acc,
                           'test/src_test_ema_1k',
                           src.test,
                           train_writer,
                           i + 1,
                           print_list,
                           full=False)

            if trg:
                save_value(M.fn_ema_acc, 'test/trg_test_ema', trg.test,
                           train_writer, i + 1, print_list)
                save_value(M.fn_ema_acc,
                           'test/trg_train_ema_1k',
                           trg.train,
                           train_writer,
                           i + 1,
                           print_list,
                           full=False)

            print_list += ['epoch', epoch]
            print(print_list)

        if saver and (i + 1) % itersave == 0:
            save_model(saver, M, model_dir, i + 1)

    # Saving final model
    if saver:
        save_model(saver, M, model_dir, i + 1)
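
This and the later variants lean on small helpers that are not shown. A minimal sketch of what `delete_existing` and `save_model` plausibly do, assuming a TF1 `tf.train.Saver`; the originals may differ:

import os
import shutil

def delete_existing(path):
    # Remove a directory tree if it already exists, so each run starts clean.
    if os.path.exists(path):
        shutil.rmtree(path)

def save_model(saver, M, model_dir, step):
    # Checkpoint the session, mirroring the call sites above.
    path = saver.save(M.sess, os.path.join(model_dir, 'model'),
                      global_step=step)
    print("Saving model to {}".format(path))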
Example #2
def train(L, FLAGS, saver=None, model_name=None):
    """
    :param L: (TensorDict) the model
    :param FLAGS: (FLAGS) contains experiment info
    :param saver: (Saver) saves models during training
    :param model_name: name of the model being run with relevant params info
    :return: None
    """
    bs = FLAGS.bs
    lrD = FLAGS.lrD
    lrG = FLAGS.lrG
    lr = FLAGS.lr
    iterep = 1000
    itersave = 20000
    n_epoch = FLAGS.epoch
    epoch = 0
    feed_dict = {L.lrD: lrD, L.lrG: lrG, L.lr: lr}

    # Create a log directory and FileWriter
    log_dir = os.path.join(FLAGS.logdir, model_name)
    delete_existing(log_dir)
    train_writer = tf.summary.FileWriter(log_dir)

    # Create a save directory
    if saver:
        model_dir = os.path.join(FLAGS.ckptdir, model_name)
        delete_existing(model_dir)
        os.makedirs(model_dir)

    print(f"Batch size: {bs}")
    print(f"Iterep: {iterep}")
    print(f"Total iterations: {n_epoch * iterep}")
    print(f"Log directory: {log_dir}")
    print(f"Checkpoint directory: {model_dir}")

    if not FLAGS.phase:
        print(colored("LGAN Training started.", "blue"))

        for i in range(n_epoch * iterep):
            # Train the discriminator (weight clipping adds an extra op)
            update_dict(L, feed_dict, FLAGS)
            if FLAGS.clip:
                summary, _, _ = L.sess.run(L.ops_disc, feed_dict)
            else:
                summary, _ = L.sess.run(L.ops_disc, feed_dict)
            train_writer.add_summary(summary, i + 1)

            # Train the generator and the classifier
            update_dict(L, feed_dict, FLAGS)
            summary, _ = L.sess.run(L.ops_gen, feed_dict)
            train_writer.add_summary(summary, i + 1)
            train_writer.flush()

            end_epoch, epoch = tb.utils.progbar(i,
                                                iterep,
                                                message='{}/{}'.format(
                                                    epoch, i),
                                                display=True)

            if end_epoch:
                summary = L.sess.run(L.ops_image, feed_dict)
                train_writer.add_summary(summary, i + 1)
                train_writer.flush()

                lrD *= FLAGS.lrDecay
                lrG *= FLAGS.lrDecay
                feed_dict.update({L.lrD: lrD, L.lrG: lrG})
                print_list = L.sess.run(L.ops_print, feed_dict)

                for j, item in enumerate(print_list):
                    if j % 2 == 0:
                        print_list[j] = item.decode("ascii")
                    else:
                        print_list[j] = round(item, 5)

                print_list += ['epoch', epoch]
                print(print_list)

            if saver and (i + 1) % itersave == 0:
                save_model(saver, L, model_dir, i + 1)

        # Saving final model
        if saver:
            save_model(saver, L, model_dir, i + 1)
        print(colored("LGAN Training ended.", "blue"))

    else:
        print(colored("LSTM Training started.", "blue"))
        min_val_mae = 10
        for i in range(n_epoch * iterep):
            update_dict(L, feed_dict, FLAGS)
            summary, _ = L.sess.run(L.ops_lstm, feed_dict)
            train_writer.add_summary(summary, i + 1)
            train_writer.flush()

            end_epoch, epoch = tb.utils.progbar(i,
                                                iterep,
                                                message='{}/{}'.format(
                                                    epoch, i),
                                                display=True)

            if end_epoch:
                # print("!")
                # summary = L.sess.run(L.ops_lstm_image, feed_dict)
                # train_writer.add_summary(summary, i + 1)
                # train_writer.flush()
                # print("!!")
                val_mae = 0
                for j in range(50):
                    val_seq_in, val_seq_out = get_val_batch(
                        j * 10, (j + 1) * 10)
                    val_seq_in = np.transpose(val_seq_in, [1, 0, 2, 3, 4])
                    val_seq_out = np.transpose(val_seq_out, [1, 0, 2, 3, 4])
                    feed_dict.update({
                        L.val_seq_in: val_seq_in,
                        L.val_seq_out: val_seq_out
                    })
                    current_val_mae = L.sess.run(L.val_mae, feed_dict)
                    val_mae += current_val_mae
                val_mae = val_mae / 50.0 * 255.0
                print_list = L.sess.run(L.ops_lstm_print, feed_dict)

                for j, item in enumerate(print_list):
                    if j % 2 == 0:
                        print_list[j] = item.decode("ascii")
                    else:
                        print_list[j] = round(item, 5)

                print_list += ['val_mae', val_mae]
                print_list += ['epoch', epoch]
                print(print_list)

                # NOTE: test-set export is intentionally disabled; drop the
                # `False and` guard to write predicted frames for the best model.
                if False and val_mae < min_val_mae:
                    min_val_mae = val_mae
                    for j in range(50):
                        test_seq_in = get_test_batch(j * 10, (j + 1) * 10)
                        test_seq_in = np.transpose(test_seq_in,
                                                   [1, 0, 2, 3, 4])
                        feed_dict.update({L.test_seq_in: test_seq_in})
                        test_seq_out_pred = L.sess.run(L.test_seq_out_pred,
                                                       feed_dict)

                        test_seq_out_pred[test_seq_out_pred < 0] = 0
                        test_seq_out_pred = (test_seq_out_pred * 255).astype(
                            np.uint8)

                        # Write the ten predicted frames for each sequence
                        os.makedirs('test_predicted', exist_ok=True)
                        for k in range(10):
                            idx = j * 10 + k
                            folder_name = os.path.join(
                                'test_predicted', 'sequence%03d' % idx)
                            os.makedirs(folder_name, exist_ok=True)
                            for t in range(10):
                                img_path = os.path.join(
                                    folder_name, 'frames%02d.png' % t)
                                this_img = Image.fromarray(
                                    test_seq_out_pred[t, k])
                                this_img.save(img_path)

                lr *= FLAGS.lrDecay
                feed_dict.update({L.lr: lr})

            if saver and (i + 1) % itersave == 0:
                save_model(saver, L, model_dir, i + 1)

        # Saving final model
        if saver:
            save_model(saver, L, model_dir, i + 1)
        print(colored("LSTM Training ended.", "blue"))
Example #3
def train(M, src=None, trg=None, saver=None, model_name=None):
    """Main training function

    Creates log file, manages datasets, trains model

    M          - (TensorDict) the model
    src        - (obj) source domain. Contains train/test Data obj
    trg        - (obj) target domain. Contains train/test Data obj
    saver      - (Saver) saves models during training
    model_name - (str) name of the model being run with relevant params info
    """
    # Training settings
    bs = 64
    iterep = 1000
    n_epoch = 80
    epoch = 0
    feed_dict = {}

    # Create a log directory and FileWriter
    log_dir = os.path.join(args.logdir, model_name)
    delete_existing(log_dir)
    train_writer = tf.summary.FileWriter(log_dir)

    # Create a save directory
    if saver:
        model_dir = os.path.join('checkpoints', model_name)
        delete_existing(model_dir)
        os.makedirs(model_dir)

    if src: get_info('Source mnist', src)
    if trg: get_info('Target svhn', trg)
    print "Batch size:", bs
    print "Iterep:", iterep
    print "Total iterations:", n_epoch * iterep
    print "Log directory:", log_dir

    for i in xrange(n_epoch * iterep):
        # Run main optimizer
        update_dict(M, feed_dict, src, trg, bs)
        summary, _ = M.sess.run(M.ops_main, feed_dict)
        train_writer.add_summary(summary, i + 1)
        train_writer.flush()

        end_epoch, epoch = tb.utils.progbar(i,
                                            iterep,
                                            message='{}/{}'.format(epoch, i),
                                            display=args.run >= 999)

        # Log end-of-epoch values
        if end_epoch:
            print_list = M.sess.run(M.ops_print, feed_dict)

            if src:
                save_acc(M,
                         'fn_ema_acc',
                         'test/src_test_ema_1k',
                         src.test,
                         train_writer,
                         i + 1,
                         print_list,
                         full=False)

            if trg:
                save_acc(M, 'fn_ema_acc', 'test/trg_test_ema', trg.test,
                         train_writer, i + 1, print_list)
                save_acc(M,
                         'fn_ema_acc',
                         'test/trg_train_ema_1k',
                         trg.train,
                         train_writer,
                         i + 1,
                         print_list,
                         full=False)

            print_list += ['epoch', epoch]
            print print_list

        if saver and (i + 1) % 20000 == 0:
            save_model(saver, M, model_dir, i + 1)

    # Saving final model
    if saver:
        save_model(saver, M, model_dir, i + 1)
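
Accuracy logging goes through `save_value(fn, tag, data, writer, step, print_list, full)`; Examples #3, #5 and #8 call the variant `save_acc(M, 'fn_ema_acc', ...)`, which reads like the same helper resolved via `getattr`. A hedged reconstruction, assuming `fn_acc` maps `(images, labels)` to a scalar accuracy and that `full=False` means "first 1000 samples only", consistent with the `*_1k` tags:

import tensorflow as tf  # TF1.x, matching tf.summary.FileWriter above

def save_value(fn_acc, tag, data, train_writer, step, print_list, full=True):
    # Evaluate on the full split, or on the first 1k samples for speed.
    n = len(data.images) if full else 1000
    acc = fn_acc(data.images[:n], data.labels[:n])
    # Write a scalar summary and mirror it into the printed list.
    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=acc)])
    train_writer.add_summary(summary, step)
    print_list += [tag.split('/')[-1], acc]
    return acc

def save_acc(M, fn_name, tag, data, train_writer, step, print_list, full=True):
    return save_value(getattr(M, fn_name), tag, data, train_writer, step,
                      print_list, full=full)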
Example #4
def train(M, src=None, trg=None, has_disc=True, saver=None, model_name=None):
    """Main training function

    Creates log file, manages datasets, trains model

    M          - (TensorDict) the model
    src        - (obj) source domain. Contains train/test Data obj
    trg        - (obj) target domain. Contains train/test Data obj
    has_disc   - (bool) whether model requires a discriminator update
    saver      - (Saver) saves models during training
    model_name - (str) name of the model being run with relevant params info
    """
    # Training settings
    bs = 64
    iterep = 1000
    itersave = 20000
    n_epoch = 80
    epoch = 0
    feed_dict = {}

    # Create a log directory and FileWriter
    log_dir = os.path.join(args.logdir, model_name)
    delete_existing(log_dir)
    train_writer = tf.summary.FileWriter(log_dir)

    # Create a save directory
    if saver:
        model_dir = os.path.join('checkpoints', model_name)
        delete_existing(model_dir)
        os.makedirs(model_dir)

    # Replace src domain with pseudolabeled trg
    if args.dirt > 0:
        print "Setting backup and updating backup model"
        src = PseudoData(args.trg, trg, M.teacher)
        M.sess.run(M.update_teacher)

        # Sanity check model
        print_list = []
        if src:
            save_value(M.fn_ema_acc, 'test/src_test_ema_1k',
                       src.test, train_writer, 0, print_list, full=False)

        if trg:
            save_value(M.fn_ema_acc, 'test/trg_test_ema',
                       trg.test, train_writer, 0, print_list)
            save_value(M.fn_ema_acc, 'test/trg_train_ema_1k',
                       trg.train, train_writer, 0, print_list, full=False)

        print print_list

    if src: get_info(args.src, src)
    if trg: get_info(args.trg, trg)
    print "Batch size:", bs
    print "Iterep:", iterep
    print "Total iterations:", n_epoch * iterep
    print "Log directory:", log_dir

    for i in xrange(n_epoch * iterep):
        # Run discriminator optimizer
        if has_disc:
            update_dict(M, feed_dict, src, trg, bs)
            summary, _ = M.sess.run(M.ops_disc, feed_dict)
            train_writer.add_summary(summary, i + 1)

        # Run main optimizer
        update_dict(M, feed_dict, src, trg, bs)
        summary, _ = M.sess.run(M.ops_main, feed_dict)
        train_writer.add_summary(summary, i + 1)
        train_writer.flush()

        end_epoch, epoch = tb.utils.progbar(i, iterep,
                                            message='{}/{}'.format(epoch, i),
                                            display=args.run >= 999)

        # Update pseudolabeler
        if args.dirt and (i + 1) % args.dirt == 0:
            print "Updating teacher model"
            M.sess.run(M.update_teacher)

        # Log end-of-epoch values
        if end_epoch:
            print_list = M.sess.run(M.ops_print, feed_dict)

            if src:
                save_value(M.fn_ema_acc, 'test/src_test_ema_1k',
                           src.test, train_writer, i + 1, print_list,
                           full=False)

            if trg:
                save_value(M.fn_ema_acc, 'test/trg_test_ema',
                           trg.test, train_writer, i + 1, print_list)
                save_value(M.fn_ema_acc, 'test/trg_train_ema_1k',
                           trg.train, train_writer, i + 1, print_list,
                           full=False)

            print_list += ['epoch', epoch]
            print print_list

        if saver and (i + 1) % itersave == 0:
            save_model(saver, M, model_dir, i + 1)

    # Saving final model
    if saver:
        save_model(saver, M, model_dir, i + 1)
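
When `args.dirt > 0`, the source domain is swapped out for `PseudoData(args.trg, trg, M.teacher)`. A sketch of the idea only: the target training images are re-paired with the teacher's predictions instead of ground-truth labels. `Data` and `teacher_predict` are stand-ins, not the authors' classes:

from collections import namedtuple
import numpy as np

Data = namedtuple('Data', ['images', 'labels'])  # stand-in for the Data obj

class PseudoData(object):
    def __init__(self, name, trg, teacher_predict):
        self.name = name
        # Hard pseudolabels from the teacher's (N, K) class probabilities.
        probs = teacher_predict(trg.train.images)
        onehot = np.eye(probs.shape[1])[probs.argmax(axis=1)]
        self.train = Data(trg.train.images, onehot)
        self.test = trg.test  # the test split keeps its true labels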
Example #5
def train(M,
          src=None,
          trg=None,
          has_disc=True,
          add_z=False,
          saver=None,
          model_name=None,
          y_prior=None):
    """Training loop with an optional discriminator. Builds log/checkpoint
    directories, anneals the temperature `tau` each step, and logs
    accuracies and sample images at epoch boundaries.
    """
    log_dir = os.path.join(args.logdir, model_name)
    delete_existing(log_dir)
    train_writer = tf.summary.FileWriter(log_dir)

    if saver:
        model_dir = os.path.join('checkpoints', model_name)
        delete_existing(model_dir)
        os.makedirs(model_dir)

    bs = 64
    iterep = 1000
    n_epoch = 200
    epoch = 0
    feed_dict = {M.phase: 1}
    n_viz = 1 if args.run >= 999 else 5

    if src: print "Src size:", src.train.images.shape
    if trg: print "Trg size:", trg.train.images.shape
    print "Viz per # epoch:", n_viz
    print "Batch size:", bs
    print "Iterep:", iterep
    print "Total iterations:", n_epoch * iterep
    print "Log directory:", log_dir

    for i in xrange(n_epoch * iterep):
        # Anneal the temperature toward a floor of 0.5
        tau = np.maximum(np.exp(-0.00003 * i), 0.5)

        # Discriminator
        if has_disc:
            update_dict(M, feed_dict, src, trg, add_z, bs, y_prior, tau)
            summary, _ = M.sess.run(M.ops_disc, feed_dict)
            train_writer.add_summary(summary, i + 1)

        # Main
        update_dict(M, feed_dict, src, trg, add_z, bs, y_prior, tau)
        summary, _ = M.sess.run(M.ops_main, feed_dict)
        train_writer.add_summary(summary, i + 1)

        train_writer.flush()
        end_epoch, epoch = tb.utils.progbar(i,
                                            iterep,
                                            message='{}/{}'.format(epoch, i),
                                            display=args.run >= 999)

        if end_epoch:
            print_list = M.sess.run(M.ops_print, feed_dict)

            if src:
                save_acc(M, 'fn_test_acc', 'test/src_test', src.test,
                         train_writer, i, print_list)
                save_acc(M, 'fn_ema_acc', 'test/src_ema', src.test,
                         train_writer, i, print_list)

            print_list += ['epoch', epoch]
            print print_list
            sys.stdout.flush()

        if end_epoch and epoch % n_viz == 0:
            if hasattr(M, 'ops_image'):
                summary = M.sess.run(M.ops_image, feed_dict)
                train_writer.add_summary(summary, i + 1)

        if saver and (i + 1) % 20000 == 0:
            path = saver.save(M.sess,
                              os.path.join(model_dir, 'model'),
                              global_step=i + 1)
            print "Saving model to {:s}".format(path)
            sys.stdout.flush()

    # Saving final model just in case
    if saver:
        path = saver.save(M.sess,
                          os.path.join(model_dir, 'model'),
                          global_step=i + 1)
        print "Saving model to {:s}".format(path)
        sys.stdout.flush()
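
Every loop above unpacks `end_epoch, epoch = tb.utils.progbar(i, iterep, ...)`. A behavior-compatible sketch of that contract, not tensorbayes' actual implementation: `end_epoch` is True on the last step of each epoch, `epoch` counts completed epochs, and `display` gates the console output:

import sys

def progbar(i, iterep, message='', display=True):
    end_epoch = (i + 1) % iterep == 0
    epoch = (i + 1) // iterep
    if display:
        sys.stdout.write('\r{} {}/{}'.format(message, i % iterep + 1, iterep))
        if end_epoch:
            sys.stdout.write('\n')
        sys.stdout.flush()
    return end_epoch, epoch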
Example #6
def train(M, src=None, trg=None, has_disc=True, saver=None, model_name=None):
    """Main training function

    Creates log file, manages datasets, trains model

    M          - (TensorDict) the model
    src        - (obj) source domain. Contains train/test Data obj
    trg        - (obj) target domain. Contains train/test Data obj
    has_disc   - (bool) whether model requires a discriminator update
    saver      - (Saver) saves models during training
    model_name - (str) name of the model being run with relevant params info
    """
    # Training settings
    iterep = 1000
    itersave = 20000
    n_epoch = 200
    epoch = 0
    feed_dict = {}

    # Create a log directory and FileWriter
    log_dir = os.path.join(args.logdir, model_name)
    delete_existing(log_dir)
    train_writer = tf.summary.FileWriter(log_dir)

    # Create a directory to save generated images
    gen_img_path = os.path.join(args.gendir, model_name)
    delete_existing(gen_img_path)
    os.makedirs(gen_img_path)

    # Create a save directory
    if saver:
        model_dir = os.path.join('checkpoints', model_name)
        delete_existing(model_dir)
        os.makedirs(model_dir)

    # Replace src domain with pseudolabeled trg
    if args.dirt > 0:
        print "Setting backup and updating backup model"
        src = PseudoData(args.trg, trg, M.teacher)
        M.sess.run(M.update_teacher)

        # Sanity check model
        print_list = []
        if src:
            save_value(M.fn_ema_acc,
                       'test/src_test_ema_1k',
                       src.test,
                       train_writer,
                       0,
                       print_list,
                       full=False)

        if trg:
            save_value(M.fn_ema_acc, 'test/trg_test_ema', trg.test,
                       train_writer, 0, print_list)
            save_value(M.fn_ema_acc,
                       'test/trg_train_ema_1k',
                       trg.train,
                       train_writer,
                       0,
                       print_list,
                       full=False)

        print print_list

    if src: get_info(args.src, src)
    if trg: get_info(args.trg, trg)
    print "Batch size:", args.bs
    print "Iterep:", iterep
    print "Total iterations:", n_epoch * iterep
    print "Log directory:", log_dir

    best_acc = -1.
    trg_acc = -1.
    for i in xrange(n_epoch * iterep):
        if has_disc:
            # Run discriminator optimizer
            update_dict(M, feed_dict, src, trg, args.bs)
            summary, _ = M.sess.run(M.ops_disc, feed_dict)
            train_writer.add_summary(summary, i + 1)

            # Run generator optimizer
            update_dict(M, feed_dict, None, trg, args.bs, noise=True)
            summary, _ = M.sess.run(M.ops_gen, feed_dict)
            train_writer.add_summary(summary, i + 1)

        # Run main optimizer
        update_dict(M, feed_dict, src, trg, args.bs, noise=True)
        summary, _ = M.sess.run(M.ops_main, feed_dict)
        train_writer.add_summary(summary, i + 1)
        train_writer.flush()

        end_epoch, epoch = tb.utils.progbar(i,
                                            iterep,
                                            message='{}/{}'.format(epoch, i),
                                            display=args.run >= 999)

        # Update pseudolabeler
        if args.dirt and (i + 1) % args.dirt == 0:
            print "Updating teacher model"
            M.sess.run(M.update_teacher)

        if (i + 1) % iterep == 0:
            gen_imgs = M.sess.run(M.trg_gen_x, feed_dict)
            manifold_h = int(np.floor(np.sqrt(args.bs)))
            manifold_w = int(np.floor(np.sqrt(args.bs)))
            visualize_results(
                gen_imgs, [manifold_h, manifold_w],
                os.path.join(gen_img_path, 'epoch_{}.png'.format(
                    (i + 1) // iterep)))

        # Log end-of-epoch values
        if end_epoch:
            print_list = M.sess.run(M.ops_print, feed_dict)

            if src:
                save_value(M.fn_ema_acc,
                           'test/src_test_ema_1k',
                           src.test,
                           train_writer,
                           i + 1,
                           print_list,
                           full=False)

            if trg:
                trg_acc = save_value(M.fn_ema_acc, 'test/trg_test_ema',
                                     trg.test, train_writer, i + 1, print_list)
                save_value(M.fn_ema_acc,
                           'test/trg_train_ema_1k',
                           trg.train,
                           train_writer,
                           i + 1,
                           print_list,
                           full=False)

            print_list += ['epoch', epoch]
            print print_list

        if saver and trg_acc > best_acc:
            print("Saving new best model")
            saver.save(M.sess, os.path.join(model_dir, 'model_best'))
            best_acc = trg_acc

    # Saving final model
    if saver:
        save_model(saver, M, model_dir, i + 1)
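
Example #6 writes a grid of generator samples once per epoch via `visualize_results(gen_imgs, [h, w], path)`. A plausible sketch, assuming NHWC images in [0, 1]; the original helper may normalize or pad differently:

import numpy as np
from PIL import Image

def visualize_results(images, size, path):
    # Tile a batch (N, H, W, C) into a size[0] x size[1] grid.
    h, w = images.shape[1], images.shape[2]
    grid = np.zeros((size[0] * h, size[1] * w, images.shape[3]))
    for idx, img in enumerate(images[:size[0] * size[1]]):
        r, c = idx // size[1], idx % size[1]
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = img
    # Rescale to 8-bit and drop a singleton channel for grayscale.
    grid = (np.squeeze(grid) * 255).astype(np.uint8)
    Image.fromarray(grid).save(path)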
Example #7
def train(M, FLAGS, saver=None, model_name=None):
    print(colored("Training is started.", "blue"))

    iterep = 1000
    itersave = 20000
    n_epoch = FLAGS.epoch
    epoch = 0
    feed_dict = {}

    # Create log directory
    log_dir = os.path.join(FLAGS.logdir, model_name)
    delete_existing(log_dir)
    train_writer = tf.summary.FileWriter(log_dir)

    # Create checkpoint directory
    if saver:
        model_dir = os.path.join(FLAGS.ckptdir, model_name)
        delete_existing(model_dir)
        os.makedirs(model_dir)

    session_file_index = random.randint(0, 49)
    session_vector_file_name = 'session_vector_' + str(session_file_index)
    session_vector_file = load_file(session_vector_file_name)

    test_session_file_index = random.randint(50, 99)
    test_session_vector_file_name = 'session_vector_' + str(
        test_session_file_index)
    test_session_vector_file = load_file(test_session_vector_file_name)
    test_session_vector = np.zeros((FLAGS.bs, 369539))
    test_timestamp_matrix = np.zeros((FLAGS.bs))

    timestamp_dict = load_file('timestamp_dict')

    print(f"Iterep: {iterep}")
    print(f"Total iterations: {n_epoch * iterep}")

    for i in range(n_epoch * iterep):
        update_dict(M, feed_dict, FLAGS.bs, session_file_index,
                    session_vector_file, timestamp_dict)
        summary, _ = M.sess.run(M.ops, feed_dict)
        train_writer.add_summary(summary, i + 1)
        train_writer.flush()

        end_epoch, epoch = tb.utils.progbar(i,
                                            iterep,
                                            message='{}/{}'.format(epoch, i),
                                            display=True)

        if end_epoch:

            # Build a held-out batch, resampling until each drawn session
            # has a nonzero timestamp
            for j in range(FLAGS.bs):
                while True:
                    test_vector_index = random.randint(0, 909)
                    test_session_index = 910 * test_session_file_index + test_vector_index
                    test_session_vector[j] = test_session_vector_file[
                        test_vector_index]
                    test_timestamp_matrix[j] = timestamp_dict[
                        test_session_index]
                    if test_timestamp_matrix[j] != 0: break

            feed_dict.update({
                M.test_sv: test_session_vector,
                M.test_ts: test_timestamp_matrix
            })

            print_list = M.sess.run(M.ops_print, feed_dict)

            for j, item in enumerate(print_list):
                if j % 2 == 0:
                    print_list[j] = item.decode("ascii")

            print_list += ['epoch', epoch]
            print(print_list)

        if saver and (i + 1) % itersave == 0:
            save_model(saver, M, model_dir, i + 1)

    # Saving final model
    if saver:
        save_model(saver, M, model_dir, i + 1)

    print(colored("Training ended.", "blue"))
Example #8
def train(M, src=None, trg=None, has_disc=True, add_z=False,
          saver=None, model_name=None, y_prior=None):
    """Training loop with an optional discriminator. In phase 1 (and past
    `args.pivot`, every `args.dirt` iterations) the source domain is
    replaced by pseudolabeled target data (`backup`), the backup model is
    updated, and the student model is optionally re-initialized.
    """
    log_dir = os.path.join(args.logdir, model_name)
    delete_existing(log_dir)
    train_writer = tf.summary.FileWriter(log_dir)

    if saver:
        model_dir = os.path.join('checkpoints', model_name)
        delete_existing(model_dir)
        os.makedirs(model_dir)

    backup = PseudoData(args.trg, trg, M)
    if args.phase == 1:
        print "Setting backup and updating backup model"
        src = backup
        M.sess.run(M.back_update)

        if args.init:
            print "Re-initialize student model"
            M.sess.run(M.init_update)

        print_list = []
        if src:
            save_acc(M, 'fn_ema_acc', 'test/src_test_ema_1k',
                     src.test,  train_writer, -1, print_list, full=False)

        if trg:
            save_acc(M, 'fn_ema_acc', 'test/trg_test_ema',
                     trg.test,  train_writer, -1, print_list)
            save_acc(M, 'fn_ema_acc', 'test/trg_train_ema_1k',
                     trg.train, train_writer, -1, print_list, full=False)

        print print_list

    bs = 64
    iterep = 1000
    n_epoch = 80
    epoch = 0
    feed_dict = {M.phase: 1}

    if src: print "Src size:", src.train.images.shape
    if trg: print "Trg size:", trg.train.images.shape
    print "Batch size:", bs
    print "Iterep:", iterep
    print "Total iterations:", n_epoch * iterep
    print "Log directory:", log_dir

    for i in xrange(n_epoch * iterep):
        # Discriminator
        if has_disc:
            update_dict(M, feed_dict, src, trg, add_z, bs, y_prior)
            summary, _ = M.sess.run(M.ops_disc, feed_dict)
            train_writer.add_summary(summary, i + 1)

        # Main
        update_dict(M, feed_dict, src, trg, add_z, bs, y_prior)

        summary, _ = M.sess.run(M.ops_main, feed_dict)
        train_writer.add_summary(summary, i + 1)
        train_writer.flush()

        end_epoch, epoch = tb.utils.progbar(i, iterep,
                                            message='{}/{}'.format(epoch, i),
                                            display=1)

        if args.dirt and (i + 1) >= args.pivot and (i + 1) % args.dirt == 0:
            print "Setting backup and updating backup model"
            src = backup
            M.sess.run(M.back_update)

            if args.init:
                print "Re-initialize student model"
                M.sess.run(M.init_update)

        if end_epoch:
            print_list = M.sess.run(M.ops_print, feed_dict)

            if src:
                save_acc(M, 'fn_ema_acc', 'test/src_test_ema_1k',
                         src.test,  train_writer, i, print_list, full=False)

            if trg:
                save_acc(M, 'fn_ema_acc', 'test/trg_test_ema',
                         trg.test,  train_writer, i, print_list)
                save_acc(M, 'fn_ema_acc', 'test/trg_train_ema_1k',
                         trg.train, train_writer, i, print_list, full=False)

            print_list += ['epoch', epoch]
            print print_list
            sys.stdout.flush()

        if end_epoch and epoch % 20 == 0:
            if hasattr(M, 'ops_image'):
                summary = M.sess.run(M.ops_image, feed_dict)
                train_writer.add_summary(summary, i + 1)

        if saver and (i + 1) % 5000 == 0:
            path = saver.save(M.sess,
                              os.path.join(model_dir, 'model'),
                              global_step=i + 1)
            print "Saving model to {:s}".format(path)
            sys.stdout.flush()

    # Saving final model just in case
    if saver:
        path = saver.save(M.sess,
                          os.path.join(model_dir, 'model'),
                          global_step=i + 1)
        print "Saving model to {:s}".format(path)
        sys.stdout.flush()
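
Several variants run `M.sess.run(M.update_teacher)` (Example #8: `M.back_update`) to overwrite the teacher/backup weights with the student's current ones before pseudolabeling. A minimal TF1 sketch of building such an op; the scope names are assumptions, and it presumes both scopes declare variables in the same order:

import tensorflow as tf  # TF1.x

def make_update_op(student_scope='class', teacher_scope='teacher'):
    student = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, student_scope)
    teacher = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, teacher_scope)
    # One assign per variable pair, grouped into a single runnable op.
    return tf.group(*[t.assign(s) for s, t in zip(student, teacher)])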