예제 #1
0
 def valid(self):
     """Evaluate the latest checkpoint on the validation set and print accuracy.

     Builds the evaluation graph without dropout (``keep_prob=1``), restores
     the newest checkpoint found in ``utility.MODEL_DIR``, runs whole batches
     of validation data through it, and prints the correct/total counts and
     the resulting accuracy. Exceptions raised while evaluating are reported
     to the coordinator and re-raised by ``coord.join``.
     """
     batch_size = 100
     images, labels = utility.inputs(train=False, batch_size=batch_size, epoch=None)
     logits = self.__inference(images, keep_prob=1)
     # Per-batch number of correct predictions (the training log labels the
     # same op "correct count in this batch").
     eval_correct = self.__evaluation(logits, labels)
     saver = tf.train.Saver()
     # The context manager closes the session even if restoring the
     # checkpoint or starting the queue runners fails.
     with tf.Session() as sess:
         saver.restore(sess, tf.train.latest_checkpoint(utility.MODEL_DIR))
         coord = tf.train.Coordinator()
         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
         try:
             # Floor division: only whole batches are evaluated. True division
             # would run an extra batch when VALID_SIZE is not a multiple of
             # batch_size, counting more samples than the denominator.
             num_iter = utility.VALID_SIZE // batch_size
             total_true_count = 0
             step = 0
             while step < num_iter and not coord.should_stop():
                 total_true_count += sess.run(eval_correct)
                 step += 1
             # Divide by the samples actually evaluated; this also stays
             # correct if the coordinator stopped the loop early.
             total_sample_count = step * batch_size
             precision = (float(total_true_count) / total_sample_count
                          if total_sample_count else 0.0)
             print('正确数量/总数: %d/%d 正确率 = %.3f' % (total_true_count, total_sample_count, precision))
         except Exception as e:
             coord.request_stop(e)
         finally:
             coord.request_stop()
         coord.join(threads)
예제 #2
0
    def train(self):
        """Train the model on the training input pipeline and checkpoint it.

        Builds the training graph (dropout keep probability 0.5), then runs
        ``train_op`` batch after batch until the input pipeline raises
        ``tf.errors.OutOfRangeError`` (presumably once the ``epoch=90`` limit
        passed to ``utility.inputs`` is reached). Every 10 steps it prints the
        loss, step time and the batch's correct count. Every 100 steps it
        writes summaries to ``utility.LOG_DIR`` and saves a checkpoint. A final
        checkpoint is saved when the input is exhausted. Any other exception is
        handed to the coordinator and re-raised by ``coord.join`` once the
        queue-runner threads have stopped.
        """
        # makedirs also creates any missing parent directories.
        for directory in (utility.LOG_DIR, utility.MODEL_DIR):
            if not os.path.exists(directory):
                os.makedirs(directory)

        images, labels = utility.inputs(train=True, batch_size=utility.BATCH_SIZE, epoch=90)
        logits = self.__inference(images, 0.5)
        loss = self.__loss(logits, labels)
        train_op = self.__training(loss)
        accuracy = self.__evaluation(logits, labels)
        saver = tf.train.Saver()
        summary_op = tf.summary.merge_all()
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            # Local variables need initializing too (presumably the input
            # pipeline's epoch counter lives there).
            tf.local_variables_initializer().run()
            writer = tf.summary.FileWriter(utility.LOG_DIR, sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            step = 0
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    _, loss_value, performance, summaries = sess.run([train_op, loss, accuracy, summary_op])
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        print('>> 已训练%d个批次: loss = %.2f (%.3f sec), 该批正确数量 = %d' % (step, loss_value, duration, performance))
                    if step % 100 == 0:
                        writer.add_summary(summaries, step)
                        # NOTE(review): MODEL_DIR is used as the checkpoint
                        # path *prefix*; confirm tf.train.latest_checkpoint(
                        # utility.MODEL_DIR) finds the files written this way.
                        saver.save(sess, utility.MODEL_DIR, global_step=step)
                    step += 1
            except tf.errors.OutOfRangeError:
                # The input pipeline is exhausted: training is finished.
                print('训练结束')
                saver.save(sess, utility.MODEL_DIR, global_step=step)
            except Exception as e:
                # Hand any other failure to the coordinator so that
                # coord.join() stops the queue runners and re-raises it.
                coord.request_stop(e)
            finally:
                coord.request_stop()
                writer.close()
            coord.join(threads)
예제 #3
0
"""Day 07: Handy Haversacks"""
import re

import utility

# Strip " bag"/" bags" and periods so each rule line contains only bag names.
inputs = utility.inputs(lambda x: re.sub(r' bags?|\.', '', x))
# Presumably passed to re.split to separate a rule's outer bag from its
# contents. The misspelled name is kept because other code may reference it.
DELIMETER = ' contain |, '
TARGET = 'shiny gold'


def traverse(t, key, previous=None):
    """Return how many distinct nodes are reachable from *key* in graph *t*.

    *t* maps a node to a list of its neighbours; a missing key means no
    neighbours. *key* itself is only counted if it can be reached again
    through a cycle.

    *previous* is an optional set that collects the reachable nodes and is
    updated in place. When omitted, a fresh set is used for each call. A
    shared ``set()`` default would leak results between calls, and the
    recursion previously ignored a caller-supplied set.
    """
    if previous is None:
        previous = set()

    for node in t.get(key, []):
        # Expand each node only once: avoids repeated work on shared
        # sub-graphs and infinite recursion on cycles.
        if node not in previous:
            previous.add(node)
            traverse(t, node, previous)

    return len(previous)


def count_bags(t, name, qty=1):
    """Return how many bags are nested, at any depth, inside one *name* bag.

    *t* maps a bag name to its content entries. Each entry is either
    ``'no other'`` or ``'<count> <bag name>'``. *qty* is accepted for
    backward compatibility but does not affect the result.

    Raises KeyError if a referenced bag has no entry in *t*.
    """
    count = 0

    for bag in t[name]:
        if bag == 'no other':
            continue

        # Split on the first space so multi-digit counts parse correctly;
        # indexing bag[0] / bag[2:] only handled single-digit counts.
        amount, next_bag = bag.split(' ', 1)
        num_bags = int(amount)
        # Each contained bag counts itself plus everything inside it.
        count += num_bags * (1 + count_bags(t, next_bag))

    return count

예제 #4
0
"""Day 03: Toboggan Trajectory"""
import utility

grid = utility.inputs()


def check_slope(grid, x, y):
    """Count the '#' cells hit while descending *grid* along slope (x, y).

    Only rows whose index is a multiple of *y* are visited. On the n-th
    visited row (counting from 0) the column is ``n * x``, wrapped around
    the row's length so the pattern repeats horizontally.
    """
    visited_rows = (row for idx, row in enumerate(grid) if idx % y == 0)
    return sum(
        1
        for step, row in enumerate(visited_rows)
        if row[(step * x) % len(row)] == '#'
    )


def part1():
    """Report the number of trees hit on the right-3 / down-1 slope."""
    right, down = 3, 1
    return utility.solution({'trees': check_slope(grid, right, down)})


def part2():
    slopes = [(1, 1), (3, 1), (5, 1), (7, 1), (1, 2)]

    product = 1
    for x, y in slopes:
예제 #5
0
"""Day 05: Binary Boarding"""
import utility

inputs = utility.inputs()


def binary_partition(string, row=range(0, 128), column=range(0, 8)):
    """Decode a boarding-pass string into a ``(row, column)`` pair.

    Each character halves one of the two ranges. 'F' keeps the lower half
    of *row* and 'B' the upper half. Any other character halves *column*,
    keeping the lower half for 'L' and the upper half otherwise. The start
    of each final range is returned.
    """
    def halve(span, keep_lower):
        half = len(span) // 2
        if keep_lower:
            return range(span.start, span.start + half)
        return range(span.stop - half, span.stop)

    for letter in string:
        if letter in ('F', 'B'):
            row = halve(row, letter == 'F')
        else:
            column = halve(column, letter == 'L')

    return row.start, column.start


def part1():
    """Decode every boarding pass in ``inputs`` into a seat ID.

    The seat ID is ``row * 8 + column``, with ``(row, column)`` produced by
    ``binary_partition``.
    """
    # NOTE(review): as shown, this builds ``seats`` but returns nothing
    # (implicitly None); the listing looks truncated here — confirm.
    seat_id = lambda x, y: x * 8 + y

    seats = [seat_id(*binary_partition(ticket)) for ticket in inputs]
예제 #6
0
"""Day 09: Encoding Error"""
import utility

inputs = utility.inputs(int)

# BUFFER: size of the sliding window of preceding numbers.
# SUM_TO: target sum searched for by part2 (presumably the part-1 answer for
# the real input). The puzzle example uses smaller values.
BUFFER, SUM_TO = 25, 138879426
if utility.TEST_FLAG:
    BUFFER, SUM_TO = 5, 127


def part1():
    """Report the first number that is not a sum of two preceding numbers.

    Every entry after the first ``BUFFER`` ones must equal the sum of two
    numbers with different values taken from the ``BUFFER`` entries
    immediately before it. The first entry that fails is reported. Returns
    None if every entry is valid.
    """
    for idx in range(BUFFER, len(inputs)):
        value = inputs[idx]
        # A set gives O(1) membership tests. Duplicates don't matter because
        # the two addends must differ in value anyway.
        window = set(inputs[idx - BUFFER:idx])
        # ``(value - x) != x`` stops a number from pairing with itself,
        # e.g. 50 from a single 25.
        if not any((value - x) in window and (value - x) != x for x in window):
            return utility.solution({'val': value}, test=127)


def part2():
    start = 0
    while True:
        idx, total, walk = start, 0, []
        while total < SUM_TO:
예제 #7
0
"""Day 2: Password Philosophy"""
import utility

inputs = utility.inputs(lambda x: x.split(' '))


def part1():
    """Count passwords whose policy letter occurs an allowed number of times.

    Each entry of ``inputs`` is a space-split line ``[bound, letter, password]``
    where ``bound`` looks like ``'lo-hi'``. Only the first character of
    ``letter`` is compared (presumably to drop a trailing ':').
    """
    def within_policy(bound, letter, password):
        occurrences = password.count(letter[0])
        lower, upper = (int(b) for b in bound.split('-'))
        return lower <= occurrences <= upper

    valid = sum(
        1
        for bound, letter, password in inputs
        if within_policy(bound, letter, password)
    )
    return utility.solution({'valid': valid})


def part2():
    """Count passwords where exactly one of two 1-based positions holds the letter.

    The two positions come from the ``'a-b'`` bound field. Only the first
    character of the letter field is compared.
    """
    valid = 0
    for bound, letter, password in inputs:
        first, second = (int(b) for b in bound.split('-'))
        at_first = password[first - 1]
        at_second = password[second - 1]
        # Exclusive or: exactly one of the two positions may match.
        if (at_first == letter[0]) != (at_second == letter[0]):
            valid += 1
    return utility.solution({'valid': valid})
예제 #8
0
"""Day 06: Custom Customs"""
import utility

# Presumably one entry per blank-line-separated group, each a list of the
# whitespace-separated answer strings in that group.
inputs = utility.inputs(lambda x: x.split(), pre_process='\n\n')

def distinct_answer(group):
    """Return the union of the answer sets in *group* (a list of sets).

    An empty group yields an empty set instead of raising TypeError.
    """
    return set().union(*group)


def distinct_group(group):
    """Return the intersection of the answer sets in *group* (a list of sets).

    An empty group yields an empty set instead of raising TypeError.
    """
    return set.intersection(*group) if group else set()

# One list per input group, holding the set of characters of each entry.
groups = [list(map(set, group)) for group in inputs]


def part1():
    """Sum, over all groups, the size of the union of the members' answers."""
    total = sum(len(distinct_answer(group)) for group in groups)
    return utility.solution({'sum': total}, test=11)


def part2():
    """Sum, over all groups, the size of the intersection of the members' answers."""
    total = sum(len(distinct_group(group)) for group in groups)
    return utility.solution({'sum': total}, test=6)


if __name__ == '__main__':
    # When run as a script, hand control to utility.cli() (presumably it
    # dispatches to part1/part2 and prints their solutions).
    utility.cli()
예제 #9
0
"""Day 1: Report Repair"""
import utility

inputs = utility.inputs(parse=int)


def part1():
    """Find two entries in the input file that add to 2020, report the product.

    The two entries must be distinct list items, so a value that appears
    only once (e.g. 1010) is never paired with itself. Returns None if no
    pair matches.
    """
    from itertools import combinations

    for e1, e2 in combinations(inputs, 2):
        if e1 + e2 == 2020:
            return utility.solution({
                'product': e1 * e2,
                'numbers': (e1, e2)
            })


def part2():
    """Find three entries in the input file that add to 2020, report the product.

    The three entries must be distinct list items; no entry is reused.
    Returns None if no triple matches.
    """
    from itertools import combinations

    for e1, e2, e3 in combinations(inputs, 3):
        if e1 + e2 + e3 == 2020:
            return utility.solution({
                'product': e1 * e2 * e3,
                'numbers': (e1, e2, e3)
            })


if __name__ == '__main__':
    # When run as a script, hand control to utility.cli() (presumably it
    # dispatches to part1/part2 and prints their solutions).
    utility.cli()
예제 #10
0
"""Day 04: Passport Processing"""
import utility

# Presumably one entry per blank-line-separated passport record, each a list
# of its whitespace-separated fields.
inputs = utility.inputs(parse=lambda line: line.split(),
                        pre_process='\n\n')


_DIGITS = '0123456789'


def _all_digits(x):
    """Return True if *x* is a non-empty string made only of ASCII digits."""
    # str.isdigit() would also accept non-ASCII digits such as '²'.
    return bool(x) and all(c in _DIGITS for c in x)


def _year_between(low, high):
    """Return a validator for a four-digit year in the inclusive range [low, high]."""
    return lambda x: len(x) == 4 and _all_digits(x) and low <= int(x) <= high


def valid_height(x):
    """Return True if *x* is a height such as ``'170cm'`` (150-193) or ``'65in'`` (59-76).

    Malformed values (missing or unknown unit, non-numeric number) return
    False instead of raising ValueError.
    """
    number, unit = x[:-2], x[-2:]
    if not _all_digits(number):
        return False
    if unit == 'cm':
        return 150 <= int(number) <= 193
    if unit == 'in':
        return 59 <= int(number) <= 76
    return False


# Field name -> validator. Each validator returns a bool and never raises on
# malformed strings, so bad passport data is rejected rather than crashing.
FIELDS = {
    'byr': _year_between(1920, 2002),
    'iyr': _year_between(2010, 2020),
    'eyr': _year_between(2020, 2030),
    'hgt': valid_height,
    'hcl': lambda x: (len(x) == 7 and x[0] == '#'
                      and all(c in '0123456789abcdef' for c in x[1:])),
    'ecl': lambda x: x in ('amb', 'blu', 'brn', 'gry', 'grn', 'hzl', 'oth'),
    'pid': lambda x: len(x) == 9 and _all_digits(x),
}
REQUIRED = set(FIELDS)