Example #1
def main(argv):
    tf.set_random_seed(FLAGS.seed)
    if not tf.gfile.IsDirectory(FLAGS.exp_dir):
        tf.gfile.MakeDirs(FLAGS.exp_dir)

    print('Made directory')
    train_set, val_set, _ = mdu.read_dataset(FLAGS.dataset)
    molecule_mapping = mdu.read_molecule_mapping_for_set(FLAGS.dataset)
    inv_mol_mapping = {v: k for k, v in enumerate(molecule_mapping)}

    if FLAGS.dataset.startswith('zinc'):
        bond_mapping = mdu.read_bond_mapping_for_set(FLAGS.dataset)
        inv_bond_mapping = {('%d_%d' % v): k
                            for k, v in enumerate(bond_mapping)}
        stereo = True
    else:
        bond_mapping = None
        inv_bond_mapping = None
        stereo = False

    # unique set of training data, used for evaluation
    train_set_unique = set(train_set)
    train_set, val_set, _ = mdu.read_molecule_graphs_set(FLAGS.dataset)

    n_node_types = len(molecule_mapping)
    print(n_node_types)
    n_edge_types = mdu.get_max_edge_type(train_set) + 1
    max_n_nodes = max(len(m.atoms) for m in train_set)

    train_set = mdu.Dataset(train_set, FLAGS.batch_size, shuffle=True)
    val_set = mdu.Dataset(val_set, FLAGS.batch_size, shuffle=True)

    # n_node_types: number of node types (assumed categorical)
    # n_edge_types: number of edge types / labels
    model_hparams = hparams.get_hparams_ChEMBL()
    print('Number of node/edge types: ', n_node_types, n_edge_types)
    print('Inside train function now...')
    with tf.device('/gpu:1'):
        if FLAGS.sample:
            sample(model_hparams,
                   train_set,
                   val_set,
                   eval_every=FLAGS.eval_every,
                   exp_dir=FLAGS.exp_dir,
                   summary_writer=None,
                   n_node_types=n_node_types,
                   n_edge_types=n_edge_types)
        else:
            train(model_hparams,
                  train_set,
                  val_set,
                  eval_every=FLAGS.eval_every,
                  exp_dir=FLAGS.exp_dir,
                  summary_writer=None,
                  n_node_types=n_node_types,
                  n_edge_types=n_edge_types)
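The two dictionary comprehensions near the top build plain value-to-index inverse lookups. A tiny self-contained sketch of the same pattern (the node types below are made-up placeholders, not the real molecule vocabulary):

# Standalone illustration of the inverse-mapping pattern above: enumerate()
# yields (index, value) pairs, so the comprehension maps each node type back
# to its integer index.
molecule_mapping = ['C', 'N', 'O', 'F']  # hypothetical node types
inv_mol_mapping = {v: k for k, v in enumerate(molecule_mapping)}
assert inv_mol_mapping['O'] == 2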
Example #2
        test_out_path = '{}/{}.result'.format(model_save_path,
                                              str(args.test_num_checkpoint))
        test(model, device, test_out_path)
        print('test output saved to ', test_out_path)

    # 6) train & valid
    else:
        # +) valid dataset
        #print('========== valid dataset ==========')
        #valid_dataset = data_utils.Dataset(
        #        version=args.version, data='valid', size=args.win_size, feature=args.feature, dims=args.feature_dims)
        #valid_loader = DataLoader(valid_dataset, batch_size=args.valid_batch_size, shuffle=False, num_workers=8)
        # +) train dataset
        print('========== train dataset ==========')
        train_dataset = data_utils.Dataset(version=args.version,
                                           data='train',
                                           size=args.win_size,
                                           feature=args.feature,
                                           dims=args.feature_dims)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.train_batch_size,
                                  shuffle=True,
                                  num_workers=8)
        print('========== train process ==========')
        # +) model init (check resume mode)
        if args.resume_mode:
            model = BaselineSAMAFRN(args.embedding_size,
                                    args.speakers).to(device)
            summary(model, input_size=(1200, 64))
            resume_checkpoint_path = '{}/epoch_{}.pth'.format(
                model_save_path, str(args.resume_num_checkpoint))
            model.load_state_dict(torch.load(resume_checkpoint_path))
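The example ends right after the resume branch; a minimal sketch of the epoch loop that would typically follow is below. The batch layout, loss, optimizer, and args.num_epochs are assumptions for illustration, not taken from the project; only the epoch_{n}.pth checkpoint naming mirrors the code above.

import torch
import torch.nn as nn

# Hypothetical continuation: plain cross-entropy training over train_loader,
# saving checkpoints in the epoch_{n}.pth format used above.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
for epoch in range(1, args.num_epochs + 1):        # args.num_epochs is assumed
    model.train()
    for features, labels in train_loader:          # batch layout is an assumption
        features, labels = features.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(features), labels)
        loss.backward()
        optimizer.step()
    torch.save(model.state_dict(),
               '{}/epoch_{}.pth'.format(model_save_path, epoch))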
Example #3
def main(unused_argv):
    del unused_argv  # Unused

    # Currently implemented for only one host
    assert (FLAGS.num_hosts == 1)

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    # Get corpus info
    corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
    n_token = corpus_info["vocab_size"]
    cutoffs = corpus_info["cutoffs"][1:-1]
    tf.compat.v1.logging.info("n_token {}".format(n_token))

    if FLAGS.do_train:
        # Get train input function
        train_data = data_utils.Dataset(
            data_dir=FLAGS.data_dir,
            record_info_dir=FLAGS.record_info_dir,
            split="train",
            per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts)

    if FLAGS.do_eval or FLAGS.do_test:
        # Get valid input function
        valid_data = data_utils.Dataset(
            data_dir=FLAGS.data_dir,
            record_info_dir=FLAGS.record_info_dir,
            split="valid",
            per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts)

        test_data = data_utils.Dataset(
            data_dir=FLAGS.data_dir,
            record_info_dir=FLAGS.record_info_dir,
            split="test",
            per_host_bsz=FLAGS.test_batch_size // FLAGS.num_hosts,
            tgt_len=FLAGS.test_tgt_len,
            num_core_per_host=FLAGS.num_core_per_host,
            num_hosts=FLAGS.num_hosts)
    else:
        valid_data = None
        test_data = None

    if FLAGS.use_tpu:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
    else:
        strategy = tf.distribute.get_strategy()
    print("Number of accelerators: ", strategy.num_replicas_in_sync)

    # Ensure the number of replicas in sync matches FLAGS.num_core_per_host
    assert (FLAGS.num_core_per_host == strategy.num_replicas_in_sync)

    chk_name = 'texl_chk'
    if FLAGS.do_train:
        train(n_token, cutoffs, train_data, valid_data, test_data, strategy,
              chk_name)
    if FLAGS.do_test:
        evaluate(n_token, cutoffs, valid_data, test_data, strategy, chk_name)
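train() and evaluate() are not shown, but with the strategy built above the model and optimizer would normally be created inside strategy.scope() so their variables are replicated across accelerators. A generic sketch of that pattern (the Keras layers are placeholders, not the Transformer-XL model):

import tensorflow as tf

# Generic distribution pattern assumed by the code above: variables created
# inside strategy.scope() are mirrored across the accelerators reported by
# strategy.num_replicas_in_sync.
strategy = tf.distribute.get_strategy()  # or the TPUStrategy built above
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(input_dim=1000, output_dim=64),  # placeholder sizes
        tf.keras.layers.Dense(1000),
    ])
    optimizer = tf.keras.optimizers.Adam(1e-4)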
    print("\nLoading data...")
    data = load_wine()

    print("\nPreprocessing data...")
    x = data['data']
    y = data['target']
    y = data_utils.to_one_hot(y)
    x_train, y_train, x_test, y_test, x_valid, y_valid = data_utils.train_test_valid_split(
        x, y, 0.7, 0.2, should_shuffle=True)

    x_train = scaler.fit_transform(x_train)
    x_test = scaler.transform(x_test)
    x_valid = scaler.transform(x_valid)

    dataset = data_utils.Dataset(x_train, y_train)
    data_manager = data_utils.DataManager(dataset, batch_size=8)

    print("\nBuilding network...")
    n_inputs = len(x[0])
    n_hidden = 2
    n_outputs = 3

    network = MLP([n_inputs, n_hidden, n_outputs],
                  initialiser=np.random.rand,
                  learning_rate=0.01)

    print("\nTraining network...")
    train(network, data_manager, x_valid, y_valid, n_epochs=250)

    print("\nTesting trained network...")
Example #5
    # 4) use cuda
    if torch.cuda.is_available():
        device = 'cuda'
        print('device is ', device)
    else:
        device = 'cpu'
        print('device is ', device)

    # 5) eval
    if args.eval_mode:
        # +) eval dataset
        print('========== eval dataset ==========')
        eval_dataset = data_utils.Dataset(track=args.track,
                                          data='eval',
                                          size=args.input_size,
                                          feature=args.feature,
                                          tag=args.data_tag)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=args.eval_batch_size,
                                 shuffle=False,
                                 num_workers=8)
        # +) load model
        print('========== eval process ==========')
        model = SiameseNetwork(args.embedding_size).to(device)
        eval_checkpoint_path = '{}/epoch_{}.pth'.format(
            model_save_path, str(args.eval_num_checkpoint))
        model.load_state_dict(torch.load(eval_checkpoint_path))
        print('model loaded from ', eval_checkpoint_path)
        # +) eval
        eval_output_path = '{}/{}.result'.format(model_save_path,
                                                 str(args.eval_num_checkpoint))
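The snippet is cut off before the evaluation loop itself; a hypothetical shape for it is sketched below. The batch contents and output format are assumptions, not taken from the project; only model, device, eval_loader, and eval_output_path come from the code above.

import torch
import torch.nn.functional as F

# Hypothetical eval loop: score utterance pairs with cosine similarity and
# write one "key score" line per trial.
model.eval()
with torch.no_grad(), open(eval_output_path, 'w') as f:
    for utt_a, utt_b, keys in eval_loader:         # batch layout is assumed
        emb_a = model(utt_a.to(device))
        emb_b = model(utt_b.to(device))
        scores = F.cosine_similarity(emb_a, emb_b)
        for key, score in zip(keys, scores.tolist()):
            f.write('{} {:.4f}\n'.format(key, score))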
Example #6
File: train.py  Project: chychen/iic
def train():
    with tf.Graph().as_default() as graph, tf.device('/cpu:0'):
        # Create a variable to count the number of train() calls. This equals the
        # number of batches processed * FLAGS.num_gpus.
        global_step = tf.train.create_global_step(graph=graph)
        optimizer = tf.train.AdamOptimizer(FLAGS.lr)
        # batch_images, batch_labels = data_utils.get_inputs(
        #     FLAGS.train_data_path, FLAGS.batch_size//FLAGS.num_gpus)
        dataset = data_utils.Dataset(
            FLAGS.train_data_path,
            FLAGS.validation_train_data_path,
            FLAGS.validation_data_path,
            FLAGS.batch_size // FLAGS.num_gpus,
            buffer_size=FLAGS.buffer_size,
            num_threads=FLAGS.num_threads)
        batch_images, batch_labels = dataset.get_next()
        tf.summary.image(
            'train images', batch_images, collections=LOG_COLLECTIONS)
        # Calculate the gradients for each model tower.
        tower_grads = []
        is_training = tf.placeholder(tf.bool)
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(FLAGS.num_gpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('tower_%d' % (i)) as scope:
                        loss = tower_loss(
                            scope, batch_images, batch_labels, is_training,
                            dataset.get_human_readable_to_label())
                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()
                        # Calculate the gradients for the batch of data on this CIFAR tower.
                        grads = optimizer.compute_gradients(loss)
                        # Keep track of the gradients across all towers.
                        tower_grads.append(grads)
        # summary
        summaries = tf.get_collection('train')
        vtrain_summaries = tf.get_collection('validation_train')
        vtest_summaries = tf.get_collection('validation_test')
        # Add a summary to track the learning rate.
        summaries.append(tf.summary.scalar('learning_rate', FLAGS.lr))
        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)
        # Add histograms for gradients.
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))
        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = optimizer.apply_gradients(
            grads, global_step=global_step)
        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))
        # Track the moving averages of all trainable variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(
            tf.trainable_variables())
        # Group all updates into a single train op.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.group(apply_gradient_op, variables_averages_op)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
        summary_op = tf.summary.merge(summaries)
        vtrain_summary_op = tf.summary.merge(vtrain_summaries)
        vtest_summary_op = tf.summary.merge(vtest_summaries)
        init = tf.global_variables_initializer()
        num_batch_per_epoch = int(1.7e6 // FLAGS.batch_size)
        with tf.Session(
                config=tf.ConfigProto(
                    allow_soft_placement=True,
                    log_device_placement=FLAGS.log_device_placement)) as sess:
            sess.run(init)
            if FLAGS.restore_path is not None:
                saver.restore(sess, FLAGS.restore_path)
                print('successfully restore model from checkpoint: %s' %
                      (FLAGS.restore_path))
            train_handle, vtrain_handle, vtest_handle = sess.run([
                dataset.train_iterator.string_handle(),
                dataset.vtrain_iterator.string_handle(),
                dataset.vtest_iterator.string_handle()
            ])
            sess.run([
                dataset.train_iterator.initializer,
                dataset.vtrain_iterator.initializer,
                dataset.vtest_iterator.initializer
            ])
            # Create a coordinator and run all QueueRunner objects
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            summary_writer = tf.summary.FileWriter(
                os.path.join(FLAGS.train_dir, 'train'), sess.graph)
            vtrain_summary_writer = tf.summary.FileWriter(
                os.path.join(FLAGS.train_dir, 'validation_train'))
            vtest_summary_writer = tf.summary.FileWriter(
                os.path.join(FLAGS.train_dir, 'validation_test'))
            batch_idx = 0
            while True:
                start_time = time.time()
                _, loss_value, global_step_v, summary_str = sess.run(
                    [train_op, loss, global_step, summary_op],
                    feed_dict={
                        dataset.handle: train_handle,
                        is_training: True
                    })
                batch_idx = global_step_v // FLAGS.num_gpus
                duration = time.time() - start_time
                # Log and evaluate every 100 batches (per GPU) and at step 0.
                if (global_step_v % (100 * FLAGS.num_gpus) == 0
                        or global_step_v == 0):
                    vtrain_loss_value, vtrain_summary_str = sess.run(
                        [loss, vtrain_summary_op],
                        feed_dict={
                            dataset.handle: vtrain_handle,
                            is_training: False
                        })
                    vtest_loss_value, vtest_summary_str = sess.run(
                        [loss, vtest_summary_op],
                        feed_dict={
                            dataset.handle: vtest_handle,
                            is_training: False
                        })
                    examples_per_sec = FLAGS.batch_size / duration
                    sec_per_batch = duration
                    format_str = (
                        '%s: batch_id %d, epoch_id %d, loss = %f vtrain_loss_value = %f vtest_loss_value = %f (%.2f examples/sec; %.2f sec/batch)'
                    )
                    print(format_str %
                          (datetime.now(), batch_idx, batch_idx //
                           num_batch_per_epoch, loss_value, vtrain_loss_value,
                           vtest_loss_value, examples_per_sec, sec_per_batch))
                    summary_writer.add_summary(summary_str, batch_idx)
                    vtrain_summary_writer.add_summary(vtrain_summary_str,
                                                      batch_idx)
                    vtest_summary_writer.add_summary(vtest_summary_str,
                                                     batch_idx)
                # Save the model checkpoint periodically.
                if global_step_v % (2 * num_batch_per_epoch *
                                    FLAGS.num_gpus) == 0:  # per 2 epochs
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=batch_idx)
                    print('successfully save model!')
            # Stop the threads
            coord.request_stop()
            # Wait for threads to stop
            coord.join(threads)
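average_gradients() is called above but not defined in the snippet; below is a sketch of the standard multi-tower averaging helper from the TF1 CIFAR-10 tutorial, which this code appears to follow (it assumes no tower produces a None gradient):

import tensorflow as tf

def average_gradients(tower_grads):
    # tower_grads is a list (one entry per GPU) of lists of (gradient, variable)
    # pairs; average each variable's gradient across towers.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...)
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), 0)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads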