Example #1
def train(epoch, train_obj, gt_obj, i):
    model.train()
    epoch_loss = 0
    m = len(train_obj) // batch_size
    #shuffle training set
    batch_loss = 0
    for idy in range(m):
        train_x, train_y = get_batch(train_obj, gt_obj, batch_size, idy, True)
        optimizer.zero_grad()
        train_x = train_x.to(device)
        train_y = train_y.to(device)
        loss = criterion(model(train_x), train_y)
        epoch_loss += loss.item()
        batch_loss += loss.item()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if (idy % 10 == 0):
            if (idy > 0):
                batch_loss /= 10
            print(
                "===> Epoch[{}], round[{}]({}/{}): Batch Loss: {:.8f}".format(
                    epoch, i, idy, m, batch_loss))
            batch_loss = 0
        if (idy == m - 1):
            # average over the batches accumulated since the last 10-step report
            batch_loss /= max(1, (m - 1) % 10)
            print(
                "===> Epoch[{}], round[{}]({}/{}): Batch Loss: {:.8f}".format(
                    epoch, i, idy, m, batch_loss))
    print("===> Epoch {} Complete: Avg. Loss: {:.8f}".format(
        epoch, epoch_loss / (m)))
    training_loss.append(epoch_loss / (m))
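get_batch here receives the whole training set, a batch index, and what appears to be a shuffle flag. A minimal sketch consistent with that call signature, assuming the data are torch tensors (an illustration, not the repository's actual loader):

import torch

def get_batch(train_obj, gt_obj, batch_size, batch_index, shuffle):
    # With shuffle=True, draw a random batch; otherwise return the
    # batch_index-th contiguous slice of the data.
    if shuffle:
        idx = torch.randint(0, len(train_obj), (batch_size,))
        return train_obj[idx], gt_obj[idx]
    start = batch_index * batch_size
    return train_obj[start:start + batch_size], gt_obj[start:start + batch_size]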
Example #2
def evaluate(segment, n_batches):

    loss_lst_velocity = []
    loss_lst_heading = []

    for iteration in range(0, n_batches):
        images, targets, targets_past = get_batch(128, segment,
                                                  num_images_back,
                                                  num_targets_forward,
                                                  num_targets_back)

        images = norm(to_var(torch.from_numpy(images)))
        targets = to_var(torch.from_numpy(targets))
        targets_past = to_var(torch.from_numpy(targets_past))

        pred = predictor(images, targets_past)

        loss_v = torch.abs(pred - targets)[:, 0::2].mean()
        loss_h = torch.abs(pred - targets)[:, 1::2].mean()

        loss_lst_velocity.append(loss_v)
        loss_lst_heading.append(loss_h)

    return (sum(loss_lst_velocity) /
            len(loss_lst_velocity)).data.cpu().numpy().tolist(), (
                sum(loss_lst_heading) /
                len(loss_lst_heading)).data.cpu().numpy().tolist()
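to_var and norm are helpers not shown in this excerpt. In pre-0.4 PyTorch code like this, to_var conventionally wraps a tensor in a Variable and moves it to the GPU when one is available; a sketch under that assumption:

import torch
from torch.autograd import Variable

def to_var(x):
    # Move to GPU if available, then wrap in a Variable (pre-0.4 idiom).
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)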
Example #3
def train_model(training_iterations, batch_size, train_data_file):
	accuracy_list, entropy_list = [], []
	x_data, y_data = load_train_data(train_data_file, train_with_only_known_age_data)
	#print("Length of input data is: ", len(x_data))
	start_time = time.time()
	saver = tf.train.Saver()
	for i in range(training_iterations):
		x_batch, y_batch = get_batch(x_data, y_data, batch_size)
		training_data = {x: x_batch, y_: y_batch}
		accrcy, entropy = sess.run([accuracy, cross_entropy], feed_dict=training_data)

		#Backpropagation
		sess.run(train_step, feed_dict=training_data)
		accuracy_list.append(accrcy)
		entropy_list.append(entropy)

		# Save checkpoints periodically so the trained model can be loaded later
		directory = "checkpoints/trained_model"
		if not os.path.exists(directory): os.makedirs(directory)
		if i % checkpoint_every == 0:
			saver.save(sess, directory, global_step=i)

		# printing the training performance
		if i % 100 == 0:
			print("Accuracy after %s training steps is: %s" % (i, accrcy))
	print("")
	print("Training process is done in time: ", time.time() - start_time, "seconds.")
	return accuracy_list, entropy_list		
Example #4
def fill_feed_dict(data_set, input_placeholder, output_placeholder):
    [input_images, output_images, input_mask, output_mask] = data_loader.get_batch(batch_size=BATCH_SIZE, image_size=IMAGE_SIZE)
    # np.ndarray.reshape returns a new array, so the results must be reassigned
    input_images = input_images.reshape((BATCH_SIZE, -1))
    output_images = output_images.reshape((BATCH_SIZE, -1))
    input_mask = input_mask.reshape((BATCH_SIZE, -1))
    output_mask = output_mask.reshape((BATCH_SIZE, -1))

    input_vec = np.concatenate((input_images, output_images, input_mask), axis = 1)

    # Create the feed_dict for the placeholders filled with the next
    # `batch size` examples.
    feed_dict = {
      input_placeholder: input_vec,
      output_placeholder: output_mask,
    }
    return feed_dict
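A minimal sketch of how this helper might be driven from a TF1-style training loop; train_op, loss, MAX_STEPS, and data_set are placeholders for this illustration, not names taken from the example above:

for step in range(MAX_STEPS):
    # Fetch the next batch into a feed_dict, then take one training step.
    feed_dict = fill_feed_dict(data_set, input_placeholder, output_placeholder)
    _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
    if step % 100 == 0:
        print('step %d, loss %.4f' % (step, loss_value))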
Example #5
    def test_get_batch_B(self):
        d = th.LongTensor([[0, 6, 12, 18],
                           [1, 7, 13, 19],
                           [2, 8, 14, 20],
                           [3, 9, 15, 21],
                           [4, 10, 16, 22],
                           [5, 11, 17, 23]])
        i = 4
        timesteps = 2
        xe = th.LongTensor([[4, 10, 16, 22]])
        ye = th.LongTensor([[5, 11, 17, 23]])
        x, y, _ = dl.get_batch(d, i, timesteps)

        self.assertEqual(int(th.all(th.eq(xe, x))), 1)
        self.assertEqual(int(th.all(th.eq(ye, y))), 1)
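The test pins down the slicing behaviour of dl.get_batch: with i=4 and timesteps=2 on six rows, the window is clipped to a single row so that y never runs past the end. A minimal implementation consistent with the expected tensors (a sketch, not necessarily the repository's code):

def get_batch(source, i, timesteps):
    # Language-model batching: x is a window of rows starting at i and y is
    # the same window shifted one row ahead, clipped at the end of the data.
    seq_len = min(timesteps, len(source) - 1 - i)
    x = source[i:i + seq_len]
    y = source[i + 1:i + 1 + seq_len]
    return x, y, seq_len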
Example #6
def train_model(training_iterations, batch_size, train_data_file):
    accuracy_list, cross_entropy_list = [], []
    xs_data, ys_data = load_train_data(train_data_file, True)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for i in range(training_iterations):
            x_batch, y_batch = get_batch(xs_data, ys_data, batch_size)
            training_data = {x: x_batch, y_: y_batch}
            accrcy, s_cross = sess.run([accuracy, cross_entropy],
                                       feed_dict=training_data)

            #Backpropagation
            sess.run(train_step, feed_dict=training_data)
            accuracy_list.append(accrcy)
            cross_entropy_list.append(s_cross)
        return accuracy_list, cross_entropy_list
Example #7
def run_training():
    train_dir = 'C:/datasets/emnist/train/'
    logs_summary_dir = './log/summary/train/'
    check_point_path = './log/model/'

    train, train_labels = data_loader.get_files(train_dir)
    train_batch, train_label_batch = data_loader.get_batch(
        train, train_labels, IMG_H, IMG_W, BATCH_SIZE, CAPACITY)
    train_logits, _ = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
    train_loss = model.losses(train_logits, train_label_batch)
    train_op = model.training(train_loss, LEARNING_RATE)
    train_acc = model.evaluation(train_logits, train_label_batch)

    summary_op = tf.summary.merge_all()
    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(logs_summary_dir,
                                             graph=sess.graph,
                                             session=sess)
        saver = tf.train.Saver(max_to_keep=1)
        if os.path.exists(os.path.join(check_point_path, 'checkpoint')):
            saver.restore(sess, tf.train.latest_checkpoint(check_point_path))
        else:
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(MAX_STEPS):
                if coord.should_stop(): break
                _, tra_loss, tra_acc = sess.run(
                    [train_op, train_loss, train_acc])

                if step % 50 == 0:
                    print('The training loss and acc respectively: %.2f %.2f' %
                          (tra_loss, tra_acc))
                    summary_total = sess.run(summary_op)
                    train_writer.add_summary(summary_total, global_step=step)

                if step % 2000 == 0 or (step + 1) == MAX_STEPS:
                    saver.save(sess, check_point_path, global_step=step)

        except tf.errors.OutOfRangeError:
            print('training done!')
        finally:
            coord.request_stop()
            # join inside the session scope so the queue-runner threads
            # shut down before the session is closed
            coord.join(threads)
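data_loader.get_batch here feeds the TF1 queue machinery visible above (Coordinator plus start_queue_runners). A sketch of a compatible implementation; the PNG format, grayscale decoding, and per-image standardization are assumptions, not the repository's actual preprocessing:

import tensorflow as tf

def get_batch(image_list, label_list, img_h, img_w, batch_size, capacity):
    # Classic TF1 input queue: slice the file/label lists, decode and
    # standardize each image, and let a background queue assemble batches.
    images = tf.cast(image_list, tf.string)
    labels = tf.cast(label_list, tf.int32)
    input_queue = tf.train.slice_input_producer([images, labels])
    label = input_queue[1]
    image = tf.image.decode_png(tf.read_file(input_queue[0]), channels=1)
    image = tf.image.resize_image_with_crop_or_pad(image, img_h, img_w)
    image = tf.image.per_image_standardization(image)
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              capacity=capacity)
    return image_batch, label_batch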
Example #8
    def train(self):
        self.epoch += 1
        self.model.train()
        total_loss = 0.
        start_time = time.time()
        number_tokens = len(self.corpus.vocab)
        hidden = self.model.init_hidden(self.batch_size)

        for batch, i in enumerate(
                range(0,
                      self.train_data.size(0) - 1, self.seq_len)):
            data, targets = get_batch(self.train_data, i, seq_len=self.seq_len)
            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to start of the dataset.
            self.model.zero_grad()
            hidden = repackage_hidden(hidden)
            output, hidden = self.model(data, hidden)
            loss = self.criterion(output.view(-1, number_tokens),
                                  targets.long())
            loss.backward()

            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.clip_grad)
            # Manual parameter update: plain (non-momentum) SGD applied
            # with the clipped gradients.
            for p in self.model.parameters():
                p.data.add_(p.grad.data, alpha=-self.learning_rate)

            total_loss += loss.item()

            if batch % self.log_interval == 0 and batch > 0:
                cur_loss = total_loss / self.log_interval
                elapsed = time.time() - start_time
                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                        self.epoch, batch,
                        len(self.train_data) // self.seq_len,
                        self.learning_rate, elapsed * 1000 / self.log_interval,
                        cur_loss, math.exp(cur_loss)))
                if self.logging:
                    self.logger.log_train(self.epoch, batch, cur_loss)
                total_loss = 0
                start_time = time.time()
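repackage_hidden is imported elsewhere; the canonical helper from PyTorch's word_language_model example, which this training loop closely follows, detaches the hidden state so backpropagation stops at the batch boundary (a sketch under that assumption):

import torch

def repackage_hidden(h):
    # Wrap hidden states in new tensors detached from their history.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)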
Example #9
def train():
    config = load_yaml_config("config.yml")
    display_step = config["model"]["display_step"]
    evaluate_step = config["model"]["evaluate_step"]
    save_step = config["model"]["save_step"]
    checkpoint_path = config["model"]["checkpoint_path"]
    pickle_path = config["data"]["pickle_path"]
    pb_path = config["model"]["pb_path"]
    model = TodAutoEncoder(config)
    print(model.input_x)
    print(model.loss)

    with open(pickle_path, "rb") as f:
        _ = pickle.load(f)
        _, sparse_test = pickle.load(f)
    card, sparse = zip(*sparse_test)
    test = dense_transform(list(sparse))

    sess = get_session()
    sess.run(tf.global_variables_initializer())

    batch_data = get_batch()
    for batch in batch_data:
        _, loss_train, step = model.step(sess, batch)
        if step % display_step == 0:
            print("step: %d => loss: %.4f" % (step, loss_train))
        if step % evaluate_step == 0:
            _, loss_test, _ = model.step(sess, test)
            print("{0:-^30}".format("evaluation loss: %.4f" % loss_test))
            print("")
        if step % save_step == 0:
            model.save(sess, checkpoint_path)
    model.save(sess, checkpoint_path)

    shutil.rmtree(pb_path, ignore_errors=True)
    builder = tf.saved_model.builder.SavedModelBuilder(pb_path)
    inputs = {'input_x': tf.saved_model.utils.build_tensor_info(model.input_x)}
    outputs = {'output': tf.saved_model.utils.build_tensor_info(model.loss)}
    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs,
        outputs=outputs,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

    builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING], {'my_signature': signature})
    builder.save()
Example #10
    def test(self):
        self.inputs = data_loader.get_batch(self.config)
        self.model_setup()
        saver = tf.train.Saver()

        print("---TESTING THE RESULTS---")
        if self._epoch is None:
            self._test_single(saver)
        elif self._epoch.isdigit():
            self._test_single(saver, epoch=int(self._epoch))
        elif ',' in self._epoch:
            range_str = list(map(int, str(self._epoch).split(',')))
            first_epoch, last_epoch = range_str[:2]
            step = range_str[2] if len(range_str) == 3 else 1
            self._output_dir = os.path.join(self._output_dir, 'test')
            for e in range(first_epoch, last_epoch, step):
                self._test_single(saver, e)
        else:
            raise AttributeError(
                'Stop testing. Unexpected epoch description parameter.')
        print("---TESTING FINISHED---")
Example #11
def evaluate(model,
             corpus,
             criterion,
             device,
             batch_size=25,
             seq_len=35,
             set='valid'):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.
    ntokens = len(corpus.vocab)
    data_source = batchify(getattr(corpus, set), batch_size, device)
    hidden = model.init_hidden(batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, seq_len):
            data, targets = get_batch(data_source, i, seq_len=seq_len)
            output, hidden = model(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat,
                                                targets.long()).item()
    return total_loss / (len(data_source) - 1)
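batchify is referenced but not shown. Assuming the standard helper from PyTorch's word_language_model example, which this evaluation loop mirrors, it reshapes the token stream into (sequence_length, batch_size) columns:

def batchify(data, batch_size, device):
    # Trim tokens that would not fill a whole column, then reshape so each
    # column of the result is one independent sequence.
    nbatch = data.size(0) // batch_size
    data = data.narrow(0, 0, nbatch * batch_size)
    return data.view(batch_size, -1).t().contiguous().to(device)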
Example #12
def test_data_generator():
    import tensorflow as tf
    from data_loader import get_batch
    from glob import glob
    batches, num_batches, tot_samples = get_batch(
        filenames=glob('games_data/*.csv'), batch_size=3)
    print(f'>>>>>>>>>>>>>>> batches: {batches}, num_batches: {num_batches}')

    # create an iterator of the correct shape and type
    iterator = tf.data.Iterator.from_structure(batches.output_types,
                                               batches.output_shapes)
    xs = iterator.get_next()
    train_init_op = iterator.make_initializer(batches)
    print(f'=============== xs: {xs}')

    with tf.Session() as sess:
        sess.run(train_init_op)
        ops = (xs[0], xs[1], xs[2], xs[3])
        _board, _from, _to, _res = sess.run(ops)
        print(f'shape of board: {_board.shape}')
        print(f'from action: {_from}')
        print(f'to action: {_to}')
        print(f'result: {_res}')
Example #13
def main(arguments):
    # Get the data we need from the checkpoint
    model, corpora = load_model_corpora(arguments.checkpoint)
    # load the sentences
    word2idx = Word2idx(corpora.vocab)
    sentences, lengths = tokenize(arguments.text_file, word2idx)
    lengths = np.cumsum([0] + lengths[:-1])
    if arguments.nonce:
        gold = nonce_gold(arguments.gold_file)
    else:
        # load the number agreement data, which should be tab separated.
        gold = pd.read_csv(arguments.gold_file, delimiter='\t',
                    names=['context','right','wrong','attractors'])
    # Get the location of the target verbs.
    gold['idx'] = gold['context'] + lengths
    # Get the predictions.
    model.eval()
    sentences = batchify(sentences, 1)
    hidden = model.init_hidden(1)
    input, _ = get_batch(sentences, 0, len(sentences))
    output, hidden = model(input, hidden)
    results = gold.apply(lambda x: result_(x, output, word2idx), axis=1)
    checkpoint = load_checkpoint(arguments.checkpoint)
    return results, checkpoint['valid_loss']
Example #14
def main():

    word2int, int2word = get_vocab()
    print("vocabulary loaded")

    g = Graph(word2int, int2word, is_training=True)
    print("Graph loaded")

    saver1 = tf.train.Saver()
    sess1 = tf.Session()
    saver1.restore(sess1, tf.train.latest_checkpoint(hp.model1_path))

    saver2 = tf.train.Saver()
    sess2 = tf.Session()
    saver2.restore(sess2, tf.train.latest_checkpoint(hp.model2_path))

    comments_list = [  #'000333.xlsx',
        #'000423.xlsx',
        #'000651.xlsx',
        #'000858.xlsx',
        #'000868.xlsx',
        #'002607.xlsx',
        #'002739.xlsx', File is not a zip file
        '300027.xlsx',
        #'600518.xlsx', No such file or directory: './data/101/comments/600518.xlsx'
        '600519.xlsx'
    ]
    ##
    results_list = [  #'000333.txt',
        #'000423.txt',
        #'000651.txt',
        #'000858.txt',
        #'000868.txt',
        #'002607.txt',
        #'002739.txt',
        '300027.txt',
        #'600518.txt',
        '600519.txt'
    ]
    """
    comments_list = ['000662.txt',
                     '002212.txt',
                     '002298.txt',
                     '300168.txt',
                     '600570.txt',
                     '600571.txt',
                     '600588.txt',
                     '600718.txt',
                     '601519.txt',
                     '603881.txt',
                     ]
    results_list = ['000662.txt',
                    '002212.txt',
                    '002298.txt',
                    '300168.txt',
                    '600570.txt',
                    '600571.txt',
                    '600588.txt',
                    '600718.txt',
                    '601519.txt',
                    '603881.txt',
                    ]
    """
    kk = 1
    for (comment_dir, result_dir) in zip(comments_list, results_list):
        print(kk, comment_dir)

        comment_name = hp.emotion_comments_path + comment_dir
        preprocessed_data, date_of_comments = read_unlabeled_stock(
            comment_name)
        print(preprocessed_data[3])
        print(date_of_comments[3])

        total = int(len(preprocessed_data) / hp.Sbatch_size)
        preprocessed_data = preprocessed_data[:total * hp.Sbatch_size]
        date_of_comments = date_of_comments[:total * hp.Sbatch_size]

        flags1 = []
        flags2 = []
        batches_unlabeled = get_batch(preprocessed_data, preprocessed_data,
                                      word2int)
        j = 1
        for batch_x, _ in batches_unlabeled:
            if j % 300 == 0:
                print('Predicting: {}/{}'.format(j, total))

            pre_flag1 = sess1.run(g.preds, feed_dict={g.x: batch_x})
            pre_flag2 = sess2.run(g.preds, feed_dict={g.x: batch_x})
            flags1.extend(pre_flag1)
            flags2.extend(pre_flag2)
            j += 1
        pos_num = 0
        neg_num = 0
        none_num = 0
        date2sentimentCount = dict()
        for (flag1, flag2, day) in zip(flags1, flags2, date_of_comments):
            if day not in date2sentimentCount:
                date2sentimentCount[day] = [0, 0, 0]  # NONE, POS, NEG
            if flag1 == 0:
                date2sentimentCount[day][0] += 1  # num_NONE + 1
                none_num += 1
            else:
                if flag2 == 1:
                    date2sentimentCount[day][1] += 1  #num_POS + 1
                    pos_num += 1
                else:
                    date2sentimentCount[day][2] += 1  # num_NEG + 1
                    neg_num += 1

        log_sentiment_classify.write(
            "{}: stock:{} pos {}, neg {}, total {}\n".format(
                datetime.datetime.now().isoformat(), result_dir, pos_num,
                neg_num, len(flags1)))

        print('Starting to write!')
        with open(os.path.join(hp.new_pos_neg_comments_path, result_dir),
                  'w') as f:
            s = sorted(date2sentimentCount.items(), key=lambda x: x[0])

            for (day, num) in s:
                #print(day, num)
                f.write(str(day))
                f.write('\t')
                f.write(str(num[0]))  # NONE
                f.write('\t')
                f.write(str(num[1]))  # POS
                f.write('\t')
                f.write(str(num[2]))  # NEG
                f.write('\n')

        # TODO: Display checked comments
        #logger.info("pos neg for %s" % comment_dir)
        kk += 1
Example #15
def train():
    tokenizer = pickle.load(open(tokenizer_path, "rb"))
    token2index = tokenizer.word_index

    if use_word2vec:
        embedding = load_word_embedding(vocab_size, token2index)

    lstm = Model(num_layers=num_layers,
                 seq_length=seq_length,
                 embedding_size=embedding_size,
                 vocab_size=vocab_size,
                 rnn_size=rnn_size,
                 use_bilstm=use_bilstm,
                 label_size=label_size,
                 embedding=embedding if use_word2vec else None)

    print("{0:-^40}".format("需要训练的参数"))
    for var in tf.trainable_variables():
        print(var.name, var.shape)

    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    sess = tf.Session(config=session_conf)

    with sess.as_default():
        global_step = tf.Variable(0, name="global_step", trainable=False)
        decaylearning_rate = tf.train.exponential_decay(
            learning_rate, global_step, 100, 0.99)
        optimizer = tf.train.AdamOptimizer(
            learning_rate=decaylearning_rate).minimize(lstm.loss,
                                                       global_step=global_step)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)

        sess.run(tf.global_variables_initializer())
        merged = tf.summary.merge_all()
        shutil.rmtree(summary_path, ignore_errors=True)
        writer = tf.summary.FileWriter(summary_path, sess.graph)

        def train_step(batch, label):
            feed_dict = {
                lstm.input_x: batch,
                lstm.input_y: label,
                lstm.dropout_keep_prob: 0.7
            }
            _, step, loss, accuracy = sess.run(
                [optimizer, global_step, lstm.loss, lstm.accuracy],
                feed_dict=feed_dict)

            return step, loss, accuracy

        def dev_step(batch, label):
            feed_dict = {
                lstm.input_x: batch,
                lstm.input_y: label,
                lstm.dropout_keep_prob: 1
            }
            step, loss, accuracy = sess.run(
                [global_step, lstm.loss, lstm.accuracy], feed_dict=feed_dict)
            print("{0:-^40}".format("evaluate"))
            print("step:{}".format(step), "==>", "loss:%.5f" % loss,
                  "accuracy:%.5f" % accuracy)

        batches = get_batch()
        x_dev, y_dev = pickle.load(open(test_path, "rb"))
        y_dev = np.array(y_dev).reshape(-1, 1)

        print("{0:-^40}".format("模型训练"))
        for data in batches:
            x_train, y_train = zip(*data)
            y_train = np.array(y_train).reshape(-1, 1)
            step, loss, accuracy = train_step(x_train, y_train)
            current_step = tf.train.global_step(sess, global_step)
            result = sess.run(merged,
                              feed_dict={
                                  lstm.input_x: x_train,
                                  lstm.input_y: y_train,
                                  lstm.dropout_keep_prob: 0.5
                              })
            writer.add_summary(result, current_step)

            if current_step % 10 == 0:
                print("global step:{}".format(step), "==>", "loss:%.5f" % loss,
                      "accuracy:%.5f" % accuracy)

            if current_step % 500 == 0:
                dev_step(x_dev, y_dev)
                print("")
        dev_step(x_dev, y_dev)

        saver.save(sess, cktp_path, global_step=current_step)

        shutil.rmtree(pb_path, ignore_errors=True)
        builder = tf.saved_model.builder.SavedModelBuilder(pb_path)
        inputs = {
            'input_x':
            tf.saved_model.utils.build_tensor_info(lstm.input_x),
            'dropout_keep_prob':
            tf.saved_model.utils.build_tensor_info(lstm.dropout_keep_prob)
        }
        outputs = {
            'output': tf.saved_model.utils.build_tensor_info(lstm.probs)
        }
        signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs,
            outputs=outputs,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING],
                                             {'my_signature': signature})
        builder.save()
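A note on the schedule configured above: with decay_steps=100, decay_rate=0.99, and the default staircase=False, tf.train.exponential_decay computes

# decayed_learning_rate = learning_rate * 0.99 ** (global_step / 100)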
Example #16
    # Get the maximum sentence length across the training, dev and test sets
    max_length = processing.get_max_length(train_x)
    # dev_length = processing.get_max_length(dev_x)
    # test_length = processing.get_max_length(test_x)
    #
    # max_length = max(train_length,dev_length,test_length)
    # # print(max_length)

    train_x_tensor = data_loader.pad(train_x, word2id, "-pad-", max_length)
    train_y_tensor = torch.from_numpy(np.array(train_y))

    trainloader = data_loader.get_batch(64, train_x_tensor, train_y_tensor)

    # model
    cnnModel = CNN.CNNlayer(len(word2id), 100, 2, [3, 4, 5], output_size=2)
    # poolmodel = CNN.CNN(len(word2id), 100, 2, [3, 4], 1)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(cnnModel.parameters(), lr=0.01)

    train_loss_plot = []
    train_acc_plot = []
    x_plot = []

    for epoch in range(100):  # loop over the dataset multiple times
        cnnModel.train()
        x_plot.append(epoch+1)
        running_loss = 0.0
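data_loader.get_batch(64, ...) evidently wraps the padded tensors into an iterable of (input, label) batches. A minimal sketch consistent with that call, assuming a standard TensorDataset/DataLoader pairing rather than the repository's actual code:

import torch

def get_batch(batch_size, x_tensor, y_tensor):
    # Pair inputs with labels and return a shuffling batch iterator.
    dataset = torch.utils.data.TensorDataset(x_tensor, y_tensor)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)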
Example #17
def test(configuration):
    """
    test loop to produce images with multiple models
    :param configuration:
    :return:
    """
    encoder = net.get_trained_encoder(configuration)
    decoder = net.get_trained_decoder(configuration)

    model_path_list = configuration['model_path_list']
    content_images_path = configuration['content_images_path']
    style_images_path = configuration['style_images_path']
    loader = configuration['loader']
    image_saving_path = configuration['image_saving_path']
    mode_list = configuration['mode_list']

    model_list = [0 for _ in range(len(model_path_list))]
    for i in range(len(model_list)):
        moment_alignment_model = net.get_moment_alignment_model(
            configuration,
            moment_mode=mode_list[i],
            use_list=True,
            list_index=i)
        checkpoint = torch.load(model_path_list[i], map_location='cpu')
        moment_alignment_model.load_state_dict(checkpoint)
        model_list[i] = moment_alignment_model

    number_content_images = len(os.listdir(content_images_path))
    number_style_images = len(os.listdir(style_images_path))
    content_image_files = [
        '{}/{}'.format(content_images_path,
                       sorted(os.listdir(content_images_path))[i])
        for i in range(number_content_images)
    ]
    style_image_files = [
        '{}/{}'.format(style_images_path,
                       sorted(os.listdir(style_images_path))[i])
        for i in range(number_style_images)
    ]

    for i in range(number_content_images):
        for j in range(number_style_images):
            print('at image {}/{}'.format(i, j))
            with torch.no_grad():
                content_image = data_loader.image_loader(
                    content_image_files[i], loader)
                style_image = data_loader.image_loader(style_image_files[j],
                                                       loader)

                result_images = [0 for _ in range(len(model_list))]
                for k in range(len(model_list)):
                    content_feature_map_batch_loader = data_loader.get_batch(
                        encoder(content_image)['r41'], 512)
                    style_feature_map_batch_loader = data_loader.get_batch(
                        encoder(style_image)['r41'], 512)
                    content_feature_map_batch = next(
                        content_feature_map_batch_loader).to(device)
                    style_feature_map_batch = next(
                        style_feature_map_batch_loader).to(device)

                    style_feature_map_batch_moments = utils.compute_moments_batches(
                        style_feature_map_batch, last_moment=7)
                    content_feature_map_batch_moments = utils.compute_moments_batches(
                        content_feature_map_batch, last_moment=7)

                    result_images[k] = decoder(model_list[k](
                        content_feature_map_batch,
                        content_feature_map_batch_moments,
                        style_feature_map_batch_moments).view(1, 512, 32, 32))
                    result_images[k] = result_images[k].squeeze(0)

                # save all images in one row
                u.save_image(
                    [
                        data_loader.imnorm(content_image, None),
                        data_loader.imnorm(style_image, None)
                    ] + result_images,
                    '{}/moment_alignment_test_image_A_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)

                # save all images in two rows
                u.save_image(
                    [
                        data_loader.imnorm(content_image, None),
                        data_loader.imnorm(style_image, None)
                    ] + [torch.ones(3, 256, 256)
                         for _ in range(2)] + result_images,
                    '{}/moment_alignment_test_image_B_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1,
                    nrow=4)

                # save all result images in one row
                u.save_image(
                    result_images,
                    '{}/moment_alignment_test_image_C_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)

                # save all result images in one row + content image
                u.save_image(
                    [data_loader.imnorm(content_image, None)] + result_images,
                    '{}/moment_alignment_test_image_D_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)

                # save all result images in one row + style image
                u.save_image(
                    [data_loader.imnorm(style_image, None)] + result_images,
                    '{}/moment_alignment_test_image_E_{}_{}.jpeg'.format(
                        image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)
Example #18
        loss_lst_velocity.append(loss_v)
        loss_lst_heading.append(loss_h)

    return (sum(loss_lst_velocity) /
            len(loss_lst_velocity)).data.cpu().numpy().tolist(), (
                sum(loss_lst_heading) /
                len(loss_lst_heading)).data.cpu().numpy().tolist()


if __name__ == "__main__":

    for iteration in range(0, 2000):

        images, targets, targets_past = get_batch(128, "train",
                                                  num_images_back,
                                                  num_targets_forward,
                                                  num_targets_back)

        images = norm(to_var(torch.from_numpy(images)))
        targets = to_var(torch.from_numpy(targets))
        targets_past = to_var(torch.from_numpy(targets_past))

        pred = predictor(images, targets_past)

        loss = torch.abs(pred - targets).mean()

        predictor.zero_grad()
        loss.backward()
        optim.step()

        if iteration % 100 == 0:
            # the snippet is truncated here; a plausible body is a progress report
            print('iteration {}: loss {:.4f}'.format(iteration, loss.item()))
Example #19
import tensorflow as tf
import data_loader
import model  # assumed: the same `model` module used in Example #7

N_CLASSES = 62
IMG_W = 28
IMG_H = 28
BATCH_SIZE = 32
CAPACITY = 4 * BATCH_SIZE
MAX_STEPS = 10000

eval_log_dir = './log/summary/test/'
eval_dir = 'C:/datasets/emnist/test/'
check_point_path = './log/model/'
eva, eva_labels = data_loader.get_files(eval_dir)
eval_batch, eval_label_batch = data_loader.get_batch(eva, eva_labels, IMG_H,
                                                     IMG_W, BATCH_SIZE,
                                                     CAPACITY)
eval_logits, _ = model.inference(eval_batch,
                                 BATCH_SIZE,
                                 N_CLASSES,
                                 training=False)
eval_acc = model.evaluation(eval_logits, eval_label_batch)
eval_loss = model.losses(eval_logits, eval_label_batch)
summary_op = tf.summary.merge_all()


def run_testing():
    TOTAL_ACC_SUM = 0
    TOTAL_LOSS_SUM = 0
    with tf.Session() as sess:
        eval_writer = tf.summary.FileWriter(eval_log_dir,
                                            graph=sess.graph)
Example #20
def test(configuration):
    """
    test loop
    :param configuration: the config file
    :return:
    """
    analytical_ada_in_module = net.AdaptiveInstanceNormalization()
    encoder = net.get_trained_encoder(configuration)
    decoder = net.get_trained_decoder(configuration)

    pretrained_model_path = configuration['pretrained_model_path']
    print('loading the moment alignment model from {}'.format(pretrained_model_path))

    content_images_path = configuration['content_images_path']
    style_images_path = configuration['style_images_path']

    loader = configuration['loader']
    unloader = configuration['unloader']

    image_saving_path = configuration['image_saving_path']
    moment_mode = configuration['moment_mode']

    moment_alignment_model = net.get_moment_alignment_model(configuration, moment_mode)
    print(moment_alignment_model)

    checkpoint = torch.load(pretrained_model_path, map_location=device)
    moment_alignment_model.load_state_dict(checkpoint)

    aligned_moment_loss = net.get_loss(configuration, moment_mode=moment_mode, lambda_1=0, lambda_2=10)

    number_content_images = len(os.listdir(content_images_path))
    number_style_images = len(os.listdir(style_images_path))

    content_image_files = ['{}/{}'.format(content_images_path, os.listdir(content_images_path)[i])
                           for i in range(number_content_images)]
    style_image_files = ['{}/{}'.format(style_images_path, os.listdir(style_images_path)[i])
                         for i in range(number_style_images)]

    for i in range(number_style_images):
        print("test_image {} at {}".format(i + 1, style_image_files[i]))

    for i in range(number_content_images):
        print("test_image {} at {}".format(i + 1, content_image_files[i]))

    for j in range(number_content_images):
        for i in range(number_style_images):
            style_image = data_loader.image_loader(style_image_files[i], loader)
            content_image = data_loader.image_loader(content_image_files[j], loader)
            with torch.no_grad():
                content_feature_maps = encoder(content_image)['r41']
                style_feature_maps = encoder(style_image)['r41']

                content_feature_map_batch_loader = data_loader.get_batch(content_feature_maps, 512)
                style_feature_map_batch_loader = data_loader.get_batch(style_feature_maps, 512)

                content_feature_map_batch = next(content_feature_map_batch_loader).to(device)
                style_feature_map_batch = next(style_feature_map_batch_loader).to(device)

                if use_MA_module:
                    style_feature_map_batch_moments = u.compute_moments_batches(style_feature_map_batch, last_moment=7)
                    content_feature_map_batch_moments = u.compute_moments_batches(content_feature_map_batch, last_moment=7)

                    out = moment_alignment_model(content_feature_map_batch,
                                                 content_feature_map_batch_moments,
                                                 style_feature_map_batch_moments,
                                                 is_test=True)

                    out_feature_map_batch_moments = u.compute_moments_batches(out, last_moment=7)

                    print_some_moments(style_feature_map_batch_moments, content_feature_map_batch_moments, out_feature_map_batch_moments)

                    loss, moment_loss, reconstruction_loss = aligned_moment_loss(content_feature_map_batch,
                                                                                 style_feature_map_batch,
                                                                                 content_feature_map_batch_moments,
                                                                                 style_feature_map_batch_moments,
                                                                                 out,
                                                                                 is_test=True)

                    print('loss: {}, moment_loss: {}, reconstruction_loss:{}'.format(
                        loss.item(), moment_loss.item(), reconstruction_loss.item()))
                else:
                    analytical_feature_maps = analytical_ada_in_module(content_feature_map_batch, style_feature_map_batch)
                    out = analytical_feature_maps

                utils.save_image([data_loader.imnorm(content_image, unloader),
                                  data_loader.imnorm(style_image, unloader),
                                  data_loader.imnorm(decoder(out.view(1, 512, 32, 32)), None)],
                                  '{}/A_image_{}_{}.jpeg'.format(image_saving_path, i, j), normalize=False)

                utils.save_image([data_loader.imnorm(decoder(out.view(1, 512, 32, 32)), None)],
                                 '{}/B_image_{}_{}.jpeg'.format(image_saving_path, i, j), normalize=False)
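In these style-transfer examples, data_loader.get_batch(feature_maps, 512) yields batches of channel maps from an encoded image (the r41 features are 512x32x32, so one batch covers all channels). A sketch of a compatible generator, assuming the (1, C, H, W) layout used above:

def get_batch(feature_maps, batch_size):
    # feature_maps: (1, C, H, W). Yield successive groups of batch_size
    # channel maps; with C == 512 and batch_size == 512 this is one batch.
    channels = feature_maps.squeeze(0)
    for start in range(0, channels.size(0), batch_size):
        yield channels[start:start + batch_size]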
Example #21
    def train(self):
        self.inputs = data_loader.get_batch(self.config)

        self.model_setup()
        self.compute_losses()

        init = (tf.global_variables_initializer(),
                tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=21)

        max_images = self.config['n_imgs']
        tf_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5),
            device_count={'GPU': 1})

        with tf.Session(config=tf_config) as sess:
            sess.run(init)

            if self._to_restore:
                chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
                saver.restore(sess, chkpt_fname)

            writer = tf.summary.FileWriter(self._output_dir)
            os.makedirs(self._output_dir, exist_ok=True)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            # Training Loop
            start_step = sess.run(tf.train.get_global_step())
            print('Starting at epoch =', start_step)
            pr_bar = tqdm(
                range(start_step, self._max_step),
                bar_format='{desc}|{bar}|{percentage:3.0f}% ETA: {remaining}')
            for epoch in pr_bar:
                pr_bar.set_description('Epoch %d/%d' % (epoch, self._max_step))
                if epoch % 10 == 0:
                    saver.save(sess,
                               os.path.join(self._checkpoint_dir, "cyclegan"),
                               global_step=epoch)
                    self.save_images(sess, epoch, self._output_dir)

                # Dealing with the learning rate as per the epoch number
                if epoch >= self._max_step:
                    break
                elif epoch < 100:
                    curr_lr = self._base_lr
                else:
                    curr_lr = self._base_lr - \
                              self._base_lr * (epoch - 100) / 100

                for i in range(0, max_images):
                    inputs = sess.run(self.inputs)
                    in_a = np.expand_dims(inputs[0], 0)
                    in_b = np.expand_dims(inputs[1], 0)

                    # Optimizing the G_A network
                    _, fake_B_temp, summary_str = sess.run(
                        [
                            self.g_A_trainer, self.fake_images_b,
                            self.g_A_loss_summ
                        ],
                        feed_dict={
                            self.input_a: in_a,
                            self.input_b: in_b,
                            self.learning_rate: curr_lr
                        })
                    writer.add_summary(summary_str, epoch * max_images + i)

                    fake_B_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_B_temp, self.fake_images_B)

                    # Optimizing the D_B network
                    _, summary_str = sess.run(
                        [self.d_B_trainer, self.d_B_loss_summ],
                        feed_dict={
                            self.input_a: in_a,
                            self.input_b: in_b,
                            self.learning_rate: curr_lr,
                            self.fake_pool_B: fake_B_temp1
                        })
                    writer.add_summary(summary_str, epoch * max_images + i)

                    # Optimizing the G_B network
                    _, fake_A_temp, summary_str = sess.run(
                        [
                            self.g_B_trainer, self.fake_images_a,
                            self.g_B_loss_summ
                        ],
                        feed_dict={
                            self.input_a: in_a,
                            self.input_b: in_b,
                            self.learning_rate: curr_lr
                        })
                    writer.add_summary(summary_str, epoch * max_images + i)

                    fake_A_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_A_temp, self.fake_images_A)

                    # Optimizing the D_A network
                    _, summary_str = sess.run(
                        [self.d_A_trainer, self.d_A_loss_summ],
                        feed_dict={
                            self.input_a: in_a,
                            self.input_b: in_b,
                            self.learning_rate: curr_lr,
                            self.fake_pool_A: fake_A_temp1
                        })
                    writer.add_summary(summary_str, epoch * max_images + i)

                    writer.flush()
                    self.num_fake_inputs += 1

                sess.run(tf.assign(self.global_step, epoch + 1))

            epoch = sess.run(tf.train.get_global_step())
            saver.save(sess,
                       os.path.join(self._checkpoint_dir, "cyclegan"),
                       global_step=epoch)
            self.save_images(sess, epoch, self._output_dir)

            total_t = pr_bar.last_print_t - pr_bar.start_t
            ave_time_per_iter = total_t / pr_bar.total
            pr_bar.close()
            total_str = 'Total time: %10.2f h, average time per iteration: %10.2f sec'
            print(total_str % (total_t / 3600.0, ave_time_per_iter))

            coord.request_stop()
            coord.join(threads)
            writer.add_graph(sess.graph)
Example #22
for index in range(var_count):
    var = trainable_vars[index]
    var_grad = var_grad_list[index]
    update.append(var.assign(var - learning_rate * var_grad))
#========

n_epoch = 50
batch_size = 150

init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run()
    for epoch in range(n_epoch):
        n_batch = X_train.shape[0] // batch_size
        for batch in range(n_batch):
            X_batch = get_batch(X_train, batch, batch_size)
            Y_batch = get_batch(y_train, batch, batch_size)
            sess.run(update, feed_dict={
                X: X_batch,
                Y: Y_batch
            })  # run the update; it holds the assignment
            # operation for each trainable variable
        train_loss, train_acc = sess.run([loss, accuracy],
                                         feed_dict={
                                             X: X_train,
                                             Y: y_train
                                         })
        test_loss, test_acc = sess.run([loss, accuracy],
                                       feed_dict={
                                           X: X_test,
                                           Y: y_test
                                       })
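The get_batch used in this example only needs to cut a fixed-size window out of the training arrays; a minimal sketch consistent with get_batch(X_train, batch, batch_size):

def get_batch(data, batch_index, batch_size):
    # Return the batch_index-th contiguous slice of data.
    start = batch_index * batch_size
    return data[start:start + batch_size]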
Example #23
def main():
    [a, b, c, d] = data_loader.get_batch(batch_size=1, image_size=5)
Example #24
def test(configuration):
    """
    test the moment alignment solution (for 2 moments) in comparison
    to the analytical solution that can be computed for mean and std
    :param configuration: the config file
    :return:
    """
    encoder = net.get_trained_encoder(configuration)
    decoder = net.get_trained_decoder(configuration)

    analytical_ada_in_module = net.AdaptiveInstanceNormalization()

    pretrained_model_path = configuration['pretrained_model_path']
    content_images_path = configuration['content_images_path']
    style_images_path = configuration['style_images_path']
    loader = configuration['loader']
    unloader = configuration['unloader']
    image_test_saving_path = configuration['image_saving_path']
    moment_mode = configuration['moment_mode']

    print('loading the moment alignment model from {}'.format(
        pretrained_model_path))
    moment_alignment_model = net.get_moment_alignment_model(
        configuration, moment_mode)
    print(moment_alignment_model)

    checkpoint = torch.load(pretrained_model_path, map_location=device)
    moment_alignment_model.load_state_dict(checkpoint)

    number_content_images = len(os.listdir(content_images_path))
    number_style_images = len(os.listdir(style_images_path))
    content_image_files = [
        '{}/{}'.format(content_images_path,
                       os.listdir(content_images_path)[i])
        for i in range(number_content_images)
    ]
    style_image_files = [
        '{}/{}'.format(style_images_path,
                       os.listdir(style_images_path)[i])
        for i in range(number_style_images)
    ]

    for i in range(number_style_images):
        print("test_image {} at {}".format(i + 1, style_image_files[i]))

    for i in range(number_content_images):
        print("test_image {} at {}".format(i + 1, content_image_files[i]))

    iterations = 0
    mean_percentages = [0, 0, 0, 0, 0, 0, 0]

    for j in range(number_content_images):
        for i in range(number_style_images):
            style_image = data_loader.image_loader(style_image_files[i],
                                                   loader)
            content_image = data_loader.image_loader(content_image_files[j],
                                                     loader)
            with torch.no_grad():
                content_feature_maps = encoder(content_image)['r41']
                style_feature_maps = encoder(style_image)['r41']

                content_feature_map_batch_loader = data_loader.get_batch(
                    content_feature_maps, 512)
                style_feature_map_batch_loader = data_loader.get_batch(
                    style_feature_maps, 512)

                content_feature_map_batch = next(
                    content_feature_map_batch_loader).to(device)
                style_feature_map_batch = next(
                    style_feature_map_batch_loader).to(device)

                style_feature_map_batch_moments = u.compute_moments_batches(
                    style_feature_map_batch)
                content_feature_map_batch_moments = u.compute_moments_batches(
                    content_feature_map_batch)

                out = moment_alignment_model(
                    content_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments)

                result_feature_maps = out

                analytical_feature_maps = analytical_ada_in_module(
                    content_feature_map_batch, style_feature_map_batch)

                a_0, a_001, a_001_l, a_01, a_01_l, a_1, a_1_l = \
                    get_distance(analytical_feature_maps, result_feature_maps)
                iterations += 1

                mean_percentages[0] += a_0
                mean_percentages[1] += a_001
                mean_percentages[2] += a_001_l
                mean_percentages[3] += a_01
                mean_percentages[4] += a_01_l
                mean_percentages[5] += a_1
                mean_percentages[6] += a_1_l

                # u.imshow(decoder(analytical_feature_maps.view(1, 512, 32, 32)), transforms.ToPILImage())

                utils.save_image([
                    data_loader.imnorm(content_image, unloader),
                    data_loader.imnorm(style_image, unloader),
                    data_loader.imnorm(
                        decoder(result_feature_maps.view(1, 512, 32, 32)),
                        None),
                    data_loader.imnorm(
                        decoder(analytical_feature_maps.view(1, 512, 32, 32)),
                        None)
                ],
                                 '{}/A_image_{}_{}.jpeg'.format(
                                     image_test_saving_path, i, j),
                                 normalize=False,
                                 pad_value=1)

                utils.save_image([
                    data_loader.imnorm(content_image, unloader),
                    data_loader.imnorm(
                        decoder(result_feature_maps.view(1, 512, 32, 32)),
                        None),
                    data_loader.imnorm(
                        decoder(analytical_feature_maps.view(1, 512, 32, 32)),
                        None)
                ],
                                 '{}/B_image_{}_{}.jpeg'.format(
                                     image_test_saving_path, i, j),
                                 normalize=False,
                                 pad_value=1)

                utils.save_image([
                    data_loader.imnorm(style_image, unloader),
                    data_loader.imnorm(
                        decoder(result_feature_maps.view(1, 512, 32, 32)),
                        None),
                    data_loader.imnorm(
                        decoder(analytical_feature_maps.view(1, 512, 32, 32)),
                        None)
                ],
                                 '{}/C_image_{}_{}.jpeg'.format(
                                     image_test_saving_path, i, j),
                                 normalize=False,
                                 pad_value=1)

                utils.save_image([
                    data_loader.imnorm(
                        decoder(result_feature_maps.view(1, 512, 32, 32)),
                        None),
                    data_loader.imnorm(
                        decoder(analytical_feature_maps.view(1, 512, 32, 32)),
                        None)
                ],
                                 '{}/D_image_{}_{}.jpeg'.format(
                                     image_test_saving_path, i, j),
                                 normalize=False,
                                 pad_value=1)

    print('averaging percentages')
    mean_percentages = [
        mean_percentages[i] / iterations for i in range(len(mean_percentages))
    ]
    print(mean_percentages)
Example #25
        word2int, int2word = get_vocab(words)
        g = Graph(word2int, int2word, is_training=False)

    saver = tf.train.Saver()
    sess = tf.Session()
    #sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))

    # start training
    if train == 0:
        lr_list = []
        logger.info("Starting training ... ")
        sess.run(tf.global_variables_initializer())
        step = 1
        for i in range(hp.Snum_epochs):
            # training-data batch generator
            batches_train = get_batch(comments_train, labels_train, word2int)
            # decay the learning rate as training progresses

            acc = []
            loss = []
            for batch_x, batch_y in batches_train:
                lr = hp.Sfactor * (hp.Shidden_units**(-0.5) * min(
                    step**(-0.5), step * hp.Swarmup**(-1.5)))
                lr_list.append(lr)
                sess.run(tf.assign(g.learning_rate, lr))
                print('lr', lr)
                feed = {g.x: batch_x, g.y: batch_y}
                batch_loss, batch_accuracy, _ = sess.run(
                    [g.loss, g.accuracy, g.optimizer], feed_dict=feed)
                logger.info(
                    "{}: epoch {}, step {}, loss {:g}, acc {:g}".format(
                        datetime.datetime.now().isoformat(), i, step,
                        batch_loss, batch_accuracy))
Example #26
        try:
            current_phase_saver.restore(sess, params.save_path)
        except Exception:
            print("Couldn't load current phase")
            print("Error:", sys.exc_info()[0])
            prev_phase_saver.restore(sess, params.save_path)
    except Exception:
        print("Couldn't load previous phase")
        print("Error:", sys.exc_info()[0])
        print("COULD NOT RESTORE MODEL. TRAINING FROM SCRATCH")
    running_loss = 0
    iterations = 0
    while True:
        try:
            start_time = time.time()
            next_batch = get_batch()
            # tdt, tdf, tdl, kdl, pl, irl, il = sess.run([train_disc_true, train_disc_false, train_disc_loss, trick_disc_loss, pixel_loss, id_regen_loss, id_loss],
            #                                   feed_dict={X: next_batch, training:False})
            # print(tdt, tdf, tdl, kdl, pl, irl, il)
            # tdt, tdf, tdl, kdl, pl, irl, il = sess.run([train_disc_true, train_disc_false, train_disc_loss, trick_disc_loss, pixel_loss, id_regen_loss, id_loss],
            #                                   feed_dict={X: next_batch, training:True})
            # print(tdt, tdf, tdl, kdl, pl, irl, il)
            # print(tdl + kdl + pl + irl + il)
            #l, _, pl = sess.run([loss, update_op, pixel_loss], feed_dict={X: next_batch, training:True})
            runners = [
                encoder_loss, decoder_loss, discriminator_loss, encoder_update,
                decoder_update, discriminator_update, id_kl_loss, pose_kl_loss,
                pixel_loss, id_regen_loss, id_loss, trick_disc_loss,
                train_disc_true, train_disc_false
            ]
Example #27
def validate(number_of_validation, criterion, encoder, moment_alignment_model,
             val_data_loader, feature_map_batch_size, writer):
    """
    the validation loop
    :param number_of_validation: the number of the current validation (for data saving)
    :param criterion: the loss criterion
    :param encoder: the encoder network
    :param moment_alignment_model: the moment alignment model
    :param val_data_loader: the dataloader for the validation dataset
    :param feature_map_batch_size: the batch size
    :param writer: the (loss) writer
    :return:
    """
    print('validating model ...')
    iteration = 0
    total_validation_loss = 0
    while True:
        try:
            data = next(val_data_loader)
        except StopIteration:
            break
        except Exception:
            print('something wrong happened with the data loader')
            continue

        # get the content_image batch
        content_image = data.get('coco').get('image')
        content_image = content_image.to(device)

        # get the style_image batch
        style_image = data.get('painter_by_numbers').get('image')
        style_image = style_image.to(device)

        with torch.no_grad():

            style_feature_maps = encoder(style_image)['r41'].to(device)
            content_feature_maps = encoder(content_image)['r41'].to(device)

            content_feature_map_batch_loader = data_loader.get_batch(
                content_feature_maps, feature_map_batch_size)
            style_feature_map_batch_loader = data_loader.get_batch(
                style_feature_maps, feature_map_batch_size)

            while True:
                try:
                    content_feature_map_batch = next(
                        content_feature_map_batch_loader).to(device)
                    style_feature_map_batch = next(
                        style_feature_map_batch_loader).to(device)
                    iteration += 1
                except StopIteration:
                    break
                except Exception:
                    print('something wrong happened with the data loader')
                    continue

                style_feature_map_batch_moments = utils.compute_moments_batches(
                    style_feature_map_batch)
                content_feature_map_batch_moments = utils.compute_moments_batches(
                    content_feature_map_batch)

                out = moment_alignment_model(
                    content_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments)

                loss, moment_loss, reconstruction_loss = criterion(
                    content_feature_map_batch, style_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments, out)

                total_validation_loss += loss.item()

                writer.write_row([
                    number_of_validation * 100 * 512,
                    loss.item(),
                    moment_loss.item(),
                    reconstruction_loss.item()
                ])

                # if iteration % 2000 == 0:
                #     print('validation loss: {:4f}'.format(loss.item()))
                #     writer.add_scalar('data/validation_loss', loss.item(),
                #                       number_of_validation * 100 * 512 + iteration)
                #     writer.add_scalar('data/validation_moment_loss', moment_loss.item(),
                #                       number_of_validation * 100 * 512 + iteration)
                #     writer.add_scalar('data/validation_reconstruction_loss', reconstruction_loss.item(),
                #                       number_of_validation * 100 * 512 + iteration)

    # guard against an empty validation loader
    mean_validation_loss = total_validation_loss / max(iteration, 1)
    writer.add_scalar('data/mean_validation_loss', mean_validation_loss,
                      number_of_validation)

    return mean_validation_loss
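Both the loop above and the training loop in the next example treat data_loader.get_batch as a generator that slices a [1, C, H, W] feature-map tensor into channel batches and signals exhaustion with StopIteration. A minimal sketch under that assumption (not necessarily the project's exact implementation):

def get_batch(feature_maps, batch_size):
    # feature_maps: a [1, C, H, W] tensor of encoder activations
    channels = feature_maps.size(1)
    for start in range(0, channels, batch_size):
        # yield a [1, batch_size, H, W] slice along the channel axis
        yield feature_maps[:, start:start + batch_size, :, :]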
Example #28
0
def train(configuration):
    """
    this is the main training loop
    :param configuration: the config file
    :return:
    """
    epochs = configuration['epochs']
    print('going to train for {} epochs'.format(epochs))

    step_printing_interval = configuration['step_printing_interval']
    print('writing to console every {} steps'.format(step_printing_interval))

    image_saving_interval = configuration['image_saving_interval']
    print('saving images every {} steps'.format(image_saving_interval))

    epoch_saving_interval = configuration['epoch_saving_interval']
    print('saving the model every {} epochs'.format(epoch_saving_interval))

    validation_interval = configuration['validation_interval']
    print('validating the model every {} iterations'.format(validation_interval))

    image_saving_path = configuration['image_saving_path']
    print('saving images to {}'.format(image_saving_path))

    loader = configuration['loader']

    model_saving_path = configuration['model_saving_path']
    print('saving models to {}'.format(model_saving_path))

    # tensorboardX_path = configuration['tensorboardX_path']
    # writer = SummaryWriter(logdir='{}/runs'.format(tensorboardX_path))
    # print('saving tensorboardX logs to {}'.format(tensorboardX_path))

    loss_writer = LossWriter(os.path.join(
        configuration['folder_structure'].get_parent_folder(), './loss/loss'),
                             buffer_size=100)
    loss_writer.write_header(columns=[
        'epoch', 'all_training_iteration', 'loss', 'moment_loss',
        'reconstruction_loss'
    ])

    # validation losses get their own file so they do not overwrite the
    # training loss file above
    validation_loss_writer = LossWriter(
        os.path.join(configuration['folder_structure'].get_parent_folder(),
                     './loss/validation_loss'))
    validation_loss_writer.write_header(columns=[
        'validation_iteration', 'loss', 'moment_loss', 'reconstruction_loss'
    ])

    # batch_size is the number of images to sample
    batch_size = 1
    feature_map_batch_size = int(configuration['feature_map_batch_size'])
    print('training in batches of {} feature maps'.format(
        feature_map_batch_size))

    coco_data_path_train = configuration['coco_data_path_train']
    painter_by_numbers_data_path_train = configuration[
        'painter_by_numbers_data_path_train']
    print('using {} and {} for training'.format(
        coco_data_path_train, painter_by_numbers_data_path_train))

    coco_data_path_val = configuration['coco_data_path_val']
    painter_by_numbers_data_path_val = configuration[
        'painter_by_numbers_data_path_val']
    print('using {} and {} for validation'.format(
        coco_data_path_val, painter_by_numbers_data_path_val))

    train_dataloader = data_loader.get_concat_dataloader(
        coco_data_path_train,
        painter_by_numbers_data_path_train,
        batch_size,
        loader=loader)
    print('got train dataloader')

    val_dataloader = data_loader.get_concat_dataloader(
        coco_data_path_val,
        painter_by_numbers_data_path_val,
        batch_size,
        loader=loader)
    print('got val dataloader')

    lambda_1 = configuration['lambda_1']
    print('lambda 1: {}'.format(lambda_1))

    lambda_2 = configuration['lambda_2']
    print('lambda 2: {}'.format(lambda_2))

    loss_moment_mode = configuration['moment_mode']
    net_moment_mode = configuration['moment_mode']
    print('loss is sum of the first {} moments'.format(loss_moment_mode))
    print('net accepts {} in-channels'.format(net_moment_mode))

    unloader = configuration['unloader']
    print('got the unloader')

    do_validation = configuration['do_validation']
    print('doing validation: {}'.format(do_validation))

    moment_alignment_model = net.get_moment_alignment_model(
        configuration, moment_mode=loss_moment_mode)
    print('got model')
    print(moment_alignment_model)

    decoder = net.get_trained_decoder(configuration)
    print('got decoder')
    print(decoder)

    print('params that require grad')
    for name, param in moment_alignment_model.named_parameters():
        if param.requires_grad:
            print(name)

    criterion = net.get_loss(configuration,
                             moment_mode=loss_moment_mode,
                             lambda_1=lambda_1,
                             lambda_2=lambda_2)

    print('got moment loss module')
    print(criterion)

    encoder = net.get_trained_encoder(configuration)
    print('got encoder')
    print(encoder)

    try:
        optimizer = optim.Adam(moment_alignment_model.parameters(),
                               lr=configuration['lr'])
    except AttributeError:
        # fall back for wrapped models (e.g. nn.DataParallel), whose
        # parameters are reached via .module
        optimizer = optim.Adam(moment_alignment_model.module.parameters(),
                               lr=configuration['lr'])
    print('got optimizer')
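    # quadratic decay of the learning rate from lr down to lr / 10 over the schedule horizon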
    schedule = QuadraticSchedule(timesteps=10000000,
                                 initial=configuration['lr'],
                                 final=configuration['lr'] / 10.)
    print('got schedule')

    print('making iterable from train dataloader')
    train_data_loader = iter(train_dataloader)
    print('train data loader iterable')
    outer_training_iteration = -1
    all_training_iteration = 0

    number_of_validation = 0
    current_validation_loss = float('inf')

    for epoch in range(1, epochs + 1):
        print('epoch: {}'.format(epoch))

        # this is the outer training loop (sampling images)
        print('training model ...')
        while True:
            try:
                data = next(train_data_loader)
                outer_training_iteration += 1
            except StopIteration:
                print('got to the end of the dataloader (StopIteration)')
                train_data_loader = iter(train_dataloader)
                break
            except Exception:
                print('something went wrong with the dataloader, continuing')
                continue

            if do_validation:
                # validate the model every validation_interval iterations
                if outer_training_iteration % validation_interval == 0:
                    print('making iterable from val dataloader')
                    val_data_loader = iter(val_dataloader)
                    print('val data loader iterable')
                    validation_loss = validate(number_of_validation, criterion,
                                               encoder, moment_alignment_model,
                                               val_data_loader,
                                               feature_map_batch_size,
                                               validation_loss_writer)
                    number_of_validation += 1
                    if validation_loss < current_validation_loss:
                        utils.save_current_best_model(
                            epoch, moment_alignment_model,
                            configuration['model_saving_path'])
                        print('got a better model')
                        current_validation_loss = validation_loss
                        print('set the new validation loss to the current one')
                    else:
                        print('this model is actually worse than the best one')

            # get the content_image batch
            content_image = data.get('coco').get('image')
            content_image = content_image.to(device)

            # get the style_image batch
            style_image = data.get('painter_by_numbers').get('image')
            style_image = style_image.to(device)

            style_feature_maps = encoder(style_image)['r41'].to(device)
            content_feature_maps = encoder(content_image)['r41'].to(device)

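            # dummy single-channel tensor; generated feature maps are concatenated
            # onto it and the placeholder channel is stripped after the inner loop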
            result_feature_maps = torch.zeros(1, 1, 32, 32)

            content_feature_map_batch_loader = data_loader.get_batch(
                content_feature_maps, feature_map_batch_size)
            style_feature_map_batch_loader = data_loader.get_batch(
                style_feature_maps, feature_map_batch_size)

            # this is the inner training loop (feature maps)
            while True:
                try:
                    content_feature_map_batch = next(
                        content_feature_map_batch_loader).to(device)
                    style_feature_map_batch = next(
                        style_feature_map_batch_loader).to(device)
                    all_training_iteration += 1
                except StopIteration:
                    break
                except Exception:
                    continue

                do_print = all_training_iteration % step_printing_interval == 0

                optimizer.zero_grad()

                style_feature_map_batch_moments = utils.compute_moments_batches(
                    style_feature_map_batch, last_moment=net_moment_mode)
                content_feature_map_batch_moments = utils.compute_moments_batches(
                    content_feature_map_batch, last_moment=net_moment_mode)

                out = moment_alignment_model(
                    content_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments)

                loss, moment_loss, reconstruction_loss = criterion(
                    content_feature_map_batch,
                    style_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments,
                    out,
                    last_moment=loss_moment_mode)

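                # the learning rate is only stepped (and the loss only logged)
                # every step_printing_interval iterations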
                if do_print:
                    set_lr(optimizer,
                           lr=schedule.get(all_training_iteration /
                                           step_printing_interval))
                    loss_writer.write_row([
                        epoch, all_training_iteration,
                        loss.item(),
                        moment_loss.item(),
                        reconstruction_loss.item()
                    ])

                # backprop
                loss.backward()
                optimizer.step()

                result_feature_maps = torch.cat(
                    [result_feature_maps,
                     out.cpu().view(1, -1, 32, 32)], 1)

                # if do_print:
                #     print('loss: {:4f}'.format(loss.item()))
                #
                #     writer.add_scalar('data/training_loss', loss.item(), all_training_iteration)
                #     writer.add_scalar('data/training_moment_loss', moment_loss.item(), all_training_iteration)
                #     writer.add_scalar('data/training_reconstruction_loss', reconstruction_loss.item(),
                #                       all_training_iteration)

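            # drop the dummy first channel, keeping the 512 generated feature maps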
            result_feature_maps = result_feature_maps[:, 1:513, :, :]
            result_img = decoder(result_feature_maps.to(device))

            if outer_training_iteration % image_saving_interval == 0:
                u.save_image([
                    data_loader.imnorm(content_image, unloader),
                    data_loader.imnorm(style_image, unloader),
                    data_loader.imnorm(result_img, None)
                ],
                             '{}/image_{}_{}__{}_{}.jpeg'.format(
                                 image_saving_path, epoch,
                                 outer_training_iteration // epoch, lambda_1,
                                 lambda_2),
                             normalize=False)

            # checkpoint the current model on the image-saving interval
            if outer_training_iteration % image_saving_interval == 0:
                utils.save_current_model(lambda_1, lambda_2,
                                         moment_alignment_model.state_dict(),
                                         optimizer.state_dict(),
                                         configuration['model_saving_path'])
Example #29
0
# TEST
def test_step(x_batch, y_batch):
    feed_dict = {
        model.input_x: x_batch,
        model.input_y: y_batch,
        # dropout is disabled at evaluation time
        model.dropout_keep_prob: 1.0
    }
    step, summaries, loss, accuracy = sess.run(
        [global_step, dev_summary_op, model.loss, model.accuracy],
        feed_dict)
    dev_summary_writer.add_summary(summaries, step)
    time_str = datetime.datetime.now().isoformat()
    logger.info("{}: step {}, loss {:g}, acc {:g}".format(
        time_str, step, loss, accuracy))

batches = data_loader.get_batch(list(zip(x_train, y_train)),
                                config.BATCH_SIZE, config.NUM_EPOCHS)

# TRAIN FOR EACH BATCH
for batch in batches:
    x_batch, y_batch = zip(*batch)
    train_step(x_batch, y_batch)
    current_step = tf.train.global_step(sess, global_step)
    if current_step % config.EVALUATE_EVERY == 0:
        logger.info("\nEvaluation:")
        test_step(x_dev, y_dev)
        logger.info("")
    if current_step % config.CHECKPOINT_EVERY == 0:
        path = saver.save(sess,
                          checkpoint_prefix,
                          global_step=current_step)
        logger.info("Saved model checkpoint to {}\n".format(path))