-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_neuron.py
88 lines (70 loc) · 2.47 KB
/
model_neuron.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from util import log
from decoder_neuron import Decoder_Neuron
class Model(object):
    """TF1 graph wrapper that regresses labels from a neuron-activity vector.

    Creates placeholders for a batch of activity vectors and their target
    labels, runs the activity through ``Decoder_Neuron``, and attaches an
    L1 or L2 regression loss plus a TensorBoard scalar summary.
    """

    def __init__(self, config, debug_information=False, is_train=True):
        """Read hyper-parameters from ``config`` and build the graph.

        Args:
            config: object exposing ``batch_size``, ``l``, ``output_dim``,
                ``output_act_fn``, ``num_d_fc``, ``d_norm_type``,
                ``loss_type``, ``load_pretrained`` and ``arch`` attributes.
            debug_information: stored as ``self.debug``; not read elsewhere
                in this class.
            is_train: forwarded to ``build`` and from there to the decoder.
        """
        self.debug = debug_information
        self.config = config
        self.batch_size = config.batch_size
        # Width of each activity vector (second dim of the `activity` placeholder).
        self.l = config.l
        self.output_dim = config.output_dim
        self.output_act_fn = config.output_act_fn
        self.num_d_fc = config.num_d_fc
        self.d_norm_type = config.d_norm_type
        self.loss_type = config.loss_type

        # Added for Decoder_mdl; not read elsewhere in this class, presumably
        # consumed by external code (e.g. a trainer) — verify against callers.
        self.load_pretrained = config.load_pretrained
        self.arch = config.arch

        # Placeholders for one fixed-size batch of inputs and targets.
        self.activity = tf.placeholder(
            name='activity', dtype=tf.float32,
            shape=[self.batch_size, self.l],
        )
        self.label = tf.placeholder(
            name='label', dtype=tf.float32,
            shape=[self.batch_size, self.output_dim],
        )

        self.build(is_train=is_train)

    def get_feed_dict(self, batch_chunk):
        """Map a batch dict onto this model's placeholders.

        Args:
            batch_chunk: mapping with keys ``'activity'`` (shape
                ``[batch_size, l]``) and ``'label'`` (shape
                ``[batch_size, output_dim]``; ``output_dim`` is presumably 3
                for ``[x, y, v]`` targets — confirm against the dataset).

        Returns:
            dict suitable for ``Session.run(..., feed_dict=...)``.
        """
        return {
            self.activity: batch_chunk['activity'],  # [batch_size, l]
            self.label: batch_chunk['label'],        # [batch_size, output_dim]
        }

    def _build_loss(self, pred_label):
        """Return ``(elementwise_loss, mean_loss)`` for ``self.loss_type``.

        Args:
            pred_label: decoder output; assumed to be broadcast-compatible
                with ``self.label`` ([batch_size, output_dim]).

        Raises:
            NotImplementedError: if ``self.loss_type`` is neither ``'l1'``
                nor ``'l2'``.
        """
        if self.loss_type == 'l1':
            ori_loss = tf.abs(self.label - pred_label)
        elif self.loss_type == 'l2':
            ori_loss = (self.label - pred_label) ** 2
        else:
            # Name the bad value so a config typo is obvious from the traceback.
            raise NotImplementedError(
                'Unsupported loss_type {!r}; expected "l1" or "l2".'.format(
                    self.loss_type))
        return ori_loss, tf.reduce_mean(ori_loss)

    def build(self, is_train=True):
        """Build decoder, loss, summaries and the ``output`` dict.

        Sets ``self.pred_label``, ``self.ori_loss`` (elementwise loss),
        ``self.loss`` (scalar mean loss) and ``self.output``.

        Args:
            is_train: forwarded to ``Decoder_Neuron``.

        Raises:
            NotImplementedError: for an unsupported ``loss_type``.
        """
        # Decoder {{{
        # =========
        # Input: activity [batch_size, l]
        # Output: predicted labels, assumed [batch_size, output_dim]
        decoder = Decoder_Neuron('Decoder_Neuron', self.output_dim,
                                 self.output_act_fn, self.num_d_fc,
                                 self.d_norm_type, is_train)
        pred_label = decoder(self.activity)
        self.pred_label = pred_label
        # }}}

        # Build losses {{{
        # =========
        self.ori_loss, self.loss = self._build_loss(pred_label)
        # }}}

        # TensorBoard summaries {{{
        # =========
        tf.summary.scalar("loss/loss", self.loss)
        # }}}

        # Output {{{
        # =========
        self.output = {
            'pred_label': pred_label
        }
        # }}}

        log.warn('\033[93mSuccessfully loaded the model.\033[0m')