Example No. 1
    def init_rnn(self, grad=True, init_f=torch.rand):
        """
        Build initial hidden and cell states for the controller LSTM.

        :param grad: whether the returned states require gradients
        :param init_f: how to initialize the LSTM states
        :return: [h_0, c_0], each of shape [num_layers, batch_size, hidden_size]
        """
        return [
            _variable(
                init_f(self.num_layers, self.batch_size,
                       self.hidden_size).float(),
                requires_grad=grad),
            _variable(
                init_f(self.num_layers, self.batch_size,
                       self.hidden_size).float(),
                requires_grad=grad)
        ]
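
A quick usage sketch (hypothetical sizes; plain tensors stand in for the repo's _variable wrapper): the returned pair is exactly what nn.LSTM expects as its initial (h_0, c_0).

import torch
import torch.nn as nn

num_layers, batch_size, hidden_size = 2, 1, 32          # hypothetical sizes
lstm = nn.LSTM(input_size=16, hidden_size=hidden_size, num_layers=num_layers)
h0 = torch.rand(num_layers, batch_size, hidden_size)    # as init_f does above
c0 = torch.rand(num_layers, batch_size, hidden_size)
x = torch.randn(5, batch_size, 16)                      # [seq_len, batch, input]
out, (hn, cn) = lstm(x, (h0, c0))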
Example No. 2
    def allocation(self, usage):
        """Computes allocation by sorting `usage`

            This corresponds to the value a = a_t[phi_t[j]] in the paper.

            Args:
                usage: tensor of shape `[batch_size, memory_size]` indicating current
                    memory usage. This is equal to u_t in the paper when we only have one
                    write head, but for multiple write heads, one should update the usage
                    while iterating through the write heads to take into account the
                    allocation returned by this function.

            Returns:
                Tensor of shape `[batch_size, memory_size]` corresponding to allocation.
            """
        # usage => [batch_size, memory_size]
        # Ensure values are not too small prior to cumprod.
        usage = eps + (1 - eps) * usage
        nonusage = 1 - usage

        # sort by descending non-usage (ascending usage); indices is phi in the paper
        sorted_nonusage, indices = nonusage.sort(-1, descending=True)
        # the inverse permutation, used to unsort the result at the end
        _, perms = indices.sort(-1)

        sorted_usage = 1 - sorted_nonusage
        ones_ = _variable(torch.ones(usage.size(0), 1))
        x_base = torch.cat([ones_, sorted_usage], -1)

        # exclusive cumulative product: prod_{i<j} u[phi[i]]
        prod_sorted_usage = x_base.cumprod(-1)
        sorted_allocation = sorted_nonusage * prod_sorted_usage[:, :-1]
        indexed = sorted_allocation.gather(-1, perms)

        return indexed
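
A worked toy example (assuming eps is a small module constant, here 1e-6, and plain tensors in place of _variable): the least-used slot should receive almost all of the allocation.

import torch

eps = 1e-6                                   # assumed value for the module constant
usage = torch.tensor([[0.9, 0.1, 0.5]])      # [batch_size=1, memory_size=3]
usage = eps + (1 - eps) * usage
nonusage = 1 - usage
sorted_nonusage, indices = nonusage.sort(-1, descending=True)
_, perms = indices.sort(-1)
prod = torch.cat([torch.ones(1, 1), 1 - sorted_nonusage], -1).cumprod(-1)
alloc = (sorted_nonusage * prod[:, :-1]).gather(-1, perms)
print(alloc)  # ~[0.005, 0.900, 0.050]: slot 1 (least used) wins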
Example No. 3
    def _link(self, prev_link, prev_precedence_weights, write_weights):
        """Calculates the new link graphs.

            For each write head, the link is a directed graph (represented by a matrix
            with entries in range [0, 1]) whose vertices are the memory locations, and
            an edge indicates temporal ordering of writes.

            Args:
                prev_link:
                    A tensor of shape `[batch_size, num_writes, memory_size, memory_size]`
                    representing the previous link graphs for each write head.
                prev_precedence_weights:
                    A tensor of shape `[batch_size, num_writes, memory_size]`
                    which is the previous "aggregated" write weights for each write head.
                write_weights:
                    A tensor of shape `[batch_size, num_writes, memory_size]`
                    containing the new locations in memory written to.

            Returns:
                A tensor of shape `[batch_size, num_writes, memory_size, memory_size]`
                containing the new link graphs for each write head.
            """

        # w_i and w_j broadcast the write weights over the rows and columns
        # of the [memory_size, memory_size] link matrix
        write_weights_i = write_weights.unsqueeze(3)
        write_weights_j = write_weights.unsqueeze(2)

        prev_precedence_weights_j = prev_precedence_weights.unsqueeze(2)
        prev_link_scale = 1 - write_weights_i - write_weights_j
        new_link = write_weights_i * prev_precedence_weights_j

        # scale old links, and add new links
        link = prev_link_scale * prev_link + new_link
        # zero the diagonal: a location is never its own temporal successor
        # (the index must cover every batch and write head)
        diag = torch.LongTensor(list(range(self._memory_size)))
        zero_idxs = _variable(diag.view(1, 1, -1, 1).repeat(
            link.size(0), self._num_writes, 1, 1))
        return link.scatter_(-1, zero_idxs, 0)
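
A toy sketch of the update (hypothetical sizes, plain tensors standing in for _variable): writing slot 0 and then slot 1 should produce a link edge from 1 back to 0.

import torch

# batch=1, one write head, 3 memory slots
prev_link = torch.zeros(1, 1, 3, 3)
prev_precedence = torch.tensor([[[1.0, 0.0, 0.0]]])  # last write went to slot 0
write = torch.tensor([[[0.0, 1.0, 0.0]]])            # now writing slot 1

w_i, w_j = write.unsqueeze(3), write.unsqueeze(2)
link = (1 - w_i - w_j) * prev_link + w_i * prev_precedence.unsqueeze(2)
print(link[0, 0])  # entry [1, 0] is 1: slot 1 was written right after slot 0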
Example No. 4
    def init_state(self, grad=True):
        """Build the zeroed initial access state.

        The list mirrors the access fields of the `DNCState` tuple:
        `access_output` `[batch_size, num_reads, word_size]` containing read
        words, followed by the access module's memory, weights, linkage,
        and usage tensors (see the inline comments).
        """
        return [
            _variable(torch.zeros(self.batch_size, self.num_reads,
                                  self.word_len),
                      requires_grad=grad),  # access output (read words)
            _variable(torch.zeros(self.batch_size, self.mem_size,
                                  self.word_len),
                      requires_grad=grad),  # memory
            _variable(torch.zeros(self.batch_size, self.num_reads,
                                  self.mem_size),
                      requires_grad=grad),  # read weights
            _variable(torch.zeros(self.batch_size, self.num_writes,
                                  self.mem_size),
                      requires_grad=grad),  # write weights
            _variable(torch.zeros(self.batch_size, self.num_writes,
                                  self.mem_size, self.mem_size),
                      requires_grad=grad),  # linkage
            _variable(torch.zeros(self.batch_size, self.num_writes,
                                  self.mem_size),
                      requires_grad=grad),  # linkage (precedence) weights
            _variable(torch.zeros(self.batch_size, self.mem_size),
                      requires_grad=grad)  # usage
        ]
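
Note the positional layout: dnc_state[0] is the access output, [1] the memory, [2] read weights, [3] write weights, [4] linkage, [5] linkage (precedence) weights, and [6] usage. Example No. 12 below indexes the state by exactly these positions.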
Example No. 5
    def init_state(self):
        return [
            # read vectors fed back to the controller, flattened
            _variable(torch.zeros(self.batch_size,
                                  self.num_reads * self.out_size),
                      requires_grad=True),
            # controller LSTM state [h, c]
            [
                _variable(torch.randn(self.num_layers, self.batch_size,
                                      self.hidden_size),
                          requires_grad=True),
                _variable(torch.randn(self.num_layers, self.batch_size,
                                      self.hidden_size),
                          requires_grad=True)
            ],
            # output RNN state [h, c]
            [
                _variable(torch.randn(1, self.batch_size, self.out_size),
                          requires_grad=True),
                _variable(torch.randn(1, self.batch_size, self.out_size),
                          requires_grad=True)
            ]
        ]
Example No. 6
def action_loss(logits, action, criterion, log=None):
    """
        Mean of the losses of the one-hot vectors encoding an action.
        :param logits: network output vector of [action, [[type_i, ent_i], for i in ents]]
        :param action: target vector of size [7]
        :param criterion: loss function
        :param log: when not None, log the per-part and mean losses
        :return: mean loss over the action parts
        """
    losses = []
    for idx, action_part in enumerate(flat(action)):
        tgt = _variable(torch.LongTensor([action_part]))
        losses.append(criterion(logits[idx], tgt))
    loss = torch.stack(losses, 0).mean()
    if log is not None:
        sl.log_loss(losses, loss)
    return loss
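
A toy sketch of the shapes this function consumes (hypothetical sizes; flat, _variable, and sl are helpers from this repo): one logits row per action component, one class index per target.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
targets = [2, 5, 7]                               # stand-in for flat(action)
logits = [torch.randn(1, 10) for _ in targets]    # one 10-way distribution per part
losses = [criterion(lg, torch.LongTensor([t])) for lg, t in zip(logits, targets)]
loss = torch.stack(losses, 0).mean()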
Example No. 7
def action_loss_for_shortest_path(logits, action, current_state, criterion, log=None):
    """
        Loss for one shortest-path move: current_state and the target node are
        encoded as zero-padded 3-digit strings and concatenated, and only the
        destination digits (indices 3-5) contribute to the loss.
        """
    from_e = "{0:03d}".format(current_state)
    to_e = "{0:03d}".format(action)
    a_list = [int(c) for c in from_e + to_e]
    losses = []
    for idx, action_part in enumerate(a_list):
        if idx > 2:  # skip the "from" digits; score only the "to" digits
            tgt = _variable(torch.LongTensor([action_part]))
            losses.append(criterion(logits[idx], tgt))
    loss = torch.stack(losses, 0).mean()
    if log is not None:
        sl.log_loss(losses, loss)
    return loss
Example No. 8
def combined_ent_loss(logits, action, criterion, log=None):
    """
        some hand tuning of penalties for illegal actions...
            trying to force learning of types.

        action type => type_e...
        :param logits: network output vector of one_hot distributions
            [action, [type_i, ent_i], for i in ents]
        :param action: target vector size [7]
        :param criterion: loss function
        :return:
        """
    losses = []
    for idx, action_part in enumerate(flat(action)):
        tgt = _variable(torch.Tensor([action_part]).float())
        losses.append(criterion(logits[idx], tgt))
    # group the per-part losses: the action head alone, then each (type, ent) pair summed
    lfs = [losses[0]]
    n = 2
    for pair in (losses[i:i + n] for i in range(1, len(losses), n)):
        lfs.append(torch.stack(pair, 0).sum())
    loss = torch.stack(lfs, 0).mean()
    if log is not None:
        sl.log_loss(losses, loss)
    return loss
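
A quick numeric check of the grouping (stand-in scalar losses): seven per-part losses collapse to four group terms before the mean.

import torch

losses = [torch.tensor(float(i)) for i in range(7)]  # stand-ins for per-part losses
lfs = [losses[0]]
for pair in (losses[i:i + 2] for i in range(1, len(losses), 2)):
    lfs.append(torch.stack(pair, 0).sum())
print([t.item() for t in lfs])                # [0.0, 3.0, 7.0, 11.0]
print(torch.stack(lfs, 0).mean().item())      # 5.25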
Example No. 9
def train_plan(args, data, DNC, lstm_state, optimizer):
    """
        Things to test after some iterations:
         - on planning phase and on

         with goals - chose a goal and work toward that
        :param args:
        :return:
        """
    criterion = nn.CrossEntropyLoss().cuda() if args.cuda else nn.CrossEntropyLoss()
    cum_correct, cum_total, prob_times, n_success = [], [], [], 0

    for trial in range(args.iters):
        start_prob = time.time()
        phase_masks = data.make_new_problem()
        n_total, n_correct, prev_action, loss, stats = 0, 0, None, 0, []
        dnc_state = DNC.init_state(grad=False)
        lstm_state = DNC.init_rnn(grad=False)
        optimizer.zero_grad()

        for phase_idx in phase_masks:

            if phase_idx == 0 or phase_idx == 1:
                inputs = _variable(data.getitem_combined())
                logits, dnc_state, lstm_state = DNC(inputs, lstm_state,
                                                    dnc_state)
                _, prev_action = data.strip_ix_mask(logits)

            elif phase_idx == 2:
                mask = _variable(data.getmask())
                inputs = torch.cat([mask, prev_action], 1)
                logits, dnc_state, lstm_state = DNC(inputs, lstm_state,
                                                    dnc_state)
                _, prev_action = data.strip_ix_mask(logits)

            else:
                # sample from best moves
                actions_star, all_actions = data.get_actions(mode='both')
                if not actions_star:
                    break
                if args.zero_at == 'step':
                    optimizer.zero_grad()

                mask = data.getmask()
                prev_action = prev_action.cuda() if args.cuda else prev_action
                pr = u.depackage(prev_action)

                final_inputs = _variable(torch.cat([mask, pr], 1))
                logits, dnc_state, lstm_state = DNC(final_inputs, lstm_state,
                                                    dnc_state)
                exp_logits = data.ix_input_to_ixs(logits)

                guided = random.random() < args.beta
                # thing 1
                if guided:  # guided loss
                    final_action, lstep = L.naive_loss(exp_logits,
                                                       actions_star,
                                                       criterion,
                                                       log=True)
                else:  # pick own move
                    final_action, lstep = L.naive_loss(exp_logits,
                                                       all_actions,
                                                       criterion,
                                                       log=True)

                # penalize illegal moves (the original TODO marked this untested)
                action_own = u.get_prediction(exp_logits)
                if args.penalty and action_own not in [tuple(flat(t))
                                                       for t in all_actions]:
                    final_loss = lstep * args.penalty
                else:
                    final_loss = lstep

                if args.opt_at == 'problem':
                    loss += final_loss
                else:

                    final_loss.backward(retain_graph=args.ret_graph)
                    if args.clip:
                        torch.nn.utils.clip_grad_norm(DNC.parameters(),
                                                      args.clip)
                    optimizer.step()
                    loss = lstep

                data.send_action(final_action)

                if (trial + 1) % args.show_details == 0:
                    action_accs = u.human_readable_res(data, all_actions,
                                                       actions_star,
                                                       action_own, guided,
                                                       lstep.data[0])
                    stats.append(action_accs)
                n_total, _ = tick(n_total, n_correct, action_own,
                                  flat(final_action))
                n_correct += 1 if action_own in [
                    tuple(flat(t)) for t in actions_star
                ] else 0
                prev_action = data.vec_to_ix(final_action)

        if stats:
            arr = np.array(stats)
            correct = len([
                1 for i in list(arr.sum(axis=1)) if i == len(stats[0])
            ]) / len(stats)
            sl.log_acc(list(arr.mean(axis=0)), correct)

        if args.opt_at == 'problem':
            floss = loss / n_total
            floss.backward(retain_graph=args.ret_graph)
            if args.clip:
                torch.nn.utils.clip_grad_norm(DNC.parameters(), args.clip)
            optimizer.step()
            sl.writer.add_scalar('losses.end', floss.data[0], sl.global_step)

        n_success += 1 if n_correct / n_total > args.passing else 0
        cum_total.append(n_total)
        cum_correct.append(n_correct)
        sl.add_scalar('recall.pct_correct', n_correct / n_total,
                      sl.global_step)
        print(
            "trial {}, step {} trial accy: {}/{}, {:0.2f}, running total {}/{}, running avg {:0.4f}, loss {:0.4f}  "
            .format(trial, sl.global_step, n_correct, n_total,
                    n_correct / n_total, n_success, trial,
                    running_avg(cum_correct, cum_total), loss.data[0]))
        end_prob = time.time()
        prob_times.append(end_prob - start_prob)
    print("solved {} out of {} -> {}".format(n_success, args.iters,
                                             n_success / args.iters))
    return DNC, optimizer, lstm_state, running_avg(cum_correct, cum_total)
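
The helpers tick and running_avg live elsewhere in this repo; as a rough sketch of what running_avg plausibly computes over the per-trial counters (an assumption, not the repo's actual code):

def running_avg(correct, total):
    # hypothetical: cumulative accuracy across all problems so far
    return sum(correct) / max(sum(total), 1)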
Example No. 10
def train_qa2(args, data, DNC, optimizer):
    """
        I am jacks liver. This is a sanity test

        0 - describe state.
        1 - describe goal.
        2 - do actions.
        3 - ask some questions
        :param args:
        :return:
        """
    criterion = nn.CrossEntropyLoss()
    cum_correct, cum_total = [], []

    for trial in range(args.iters):
        phase_masks = data.make_new_problem()
        n_total, n_correct, loss = 0, 0, 0
        dnc_state = DNC.init_state(grad=False)
        optimizer.zero_grad()

        for phase_idx in phase_masks:
            if phase_idx == 0 or phase_idx == 1:
                inputs = _variable(data.getitem_combined())
                logits, dnc_state = DNC(inputs, dnc_state)
            else:
                final_moves = data.get_actions(mode='one')
                if not final_moves:
                    break
                data.send_action(final_moves[0])
                mask = data.phase_oh[2].unsqueeze(0)
                inputs2 = _variable(
                    torch.cat([mask, data.vec_to_ix(final_moves[0])], 1))
                logits, dnc_state = DNC(inputs2, dnc_state)

                for _ in range(args.num_tests):
                    # ask where is ---?
                    if args.zero_at == 'step':
                        optimizer.zero_grad()
                    masked_input, mask_chunk, ground_truth = data.masked_input(
                    )
                    logits, dnc_state = DNC(_variable(masked_input), dnc_state)
                    expanded_logits = data.ix_input_to_ixs(logits)

                    # losses
                    lstep = L.action_loss(expanded_logits,
                                          ground_truth,
                                          criterion,
                                          log=True)
                    if args.opt_at == 'problem':
                        loss += lstep
                    else:
                        lstep.backward(retain_graph=args.ret_graph)
                        optimizer.step()
                        loss = lstep

                    # update counters
                    prediction = u.get_prediction(expanded_logits, [3, 4])
                    n_total, n_correct = tick(n_total, n_correct, mask_chunk,
                                              prediction)

        if args.opt_at == 'problem':
            loss.backward(retain_graph=args.ret_graph)
            optimizer.step()
            sl.writer.add_scalar('losses.end', loss.data[0], sl.global_step)

        cum_total.append(n_total)
        cum_correct.append(n_correct)
        sl.writer.add_scalar('recall.pct_correct', n_correct / n_total,
                             sl.global_step)
        print(
            "trial: {}, step:{}, accy {:0.4f}, cum_score {:0.4f}, loss: {:0.4f}"
            .format(trial, sl.global_step, n_correct / n_total,
                    running_avg(cum_correct, cum_total), loss.data[0]))
    return DNC, optimizer, dnc_state, running_avg(cum_correct, cum_total)
Example No. 11
def train_shortest_path_plan(args, data, DNC, lstm_state, optimizer):
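    """
        Train the DNC to plan shortest paths on generated graphs: phases 0-1
        present the problem, phase 2 is a transition step, and the final phase
        moves node to node, mixing guided and self-chosen moves via args.beta.
        """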
    criterion = nn.CrossEntropyLoss().cuda() if args.cuda else nn.CrossEntropyLoss()
    cum_correct, cum_total, prob_times, n_success = [], [], [], 0

    for trial in range(args.iters):
        if trial % 100 == 0:
            ___score, ___goal_score = test_shortest_path_planning(
                args, data, DNC, lstm_state)
        start_prob = time.time()
        phase_masks = data.make_new_graph()
        print("Shortest Path :: ", data.shortest_path)
        n_total, n_correct, prev_action, loss, stats = 0, 0, None, 0, []
        dnc_state = DNC.init_state(grad=False)
        lstm_state = DNC.init_rnn(grad=False)
        optimizer.zero_grad()

        for phase_idx in phase_masks:

            if phase_idx == 0 or phase_idx == 1:
                inputs = _variable(data.getitem_combined())
                logits, dnc_state, lstm_state = DNC(inputs, lstm_state,
                                                    dnc_state)
                _, prev_action = data.strip_ix_mask(logits)

            elif phase_idx == 2:
                mask = _variable(data.getmask())
                inputs = torch.cat([mask, prev_action], 1)
                logits, dnc_state, lstm_state = DNC(inputs, lstm_state,
                                                    dnc_state)
                _, prev_action = data.strip_ix_mask(logits)

            else:
                best_nodes, all_nodes = data.get_actions()
                if not best_nodes:
                    break
                if args.zero_at == 'step':
                    optimizer.zero_grad()

                mask = data.getmask()
                prev_action = prev_action.cuda() if args.cuda else prev_action
                # print("previous action: ", prev_action)
                pr = u.depackage(prev_action)

                final_inputs = _variable(torch.cat([mask, pr], 1))
                logits, dnc_state, lstm_state = DNC(final_inputs, lstm_state,
                                                    dnc_state)
                exp_logits = data.ix_input_to_ixs(logits)
                current_state = data.STATE
                guided = random.random() < args.beta
                sup_flag = None
                if guided:  # guided loss
                    final_action, lstep = L.naive_loss_for_shortest_path(
                        exp_logits,
                        best_nodes,
                        current_state,
                        criterion,
                        log=True)
                    sup_flag = "Yes"
                else:  # pick own move
                    final_action, lstep = L.naive_loss_for_shortest_path(
                        exp_logits,
                        all_nodes,
                        current_state,
                        criterion,
                        log=True)
                    sup_flag = "No"
                action_own = u.get_prediction(exp_logits)

                final_loss = lstep
                final_loss.backward(retain_graph=args.ret_graph)
                if args.clip:
                    torch.nn.utils.clip_grad_norm(DNC.parameters(), args.clip)
                optimizer.step()
                loss = lstep
                print("Supervised: {}, {} index, from: {}, to: {}, loss: {:0.4f}".format(
                    sup_flag, data.current_index, current_state, final_action,
                    final_loss.data[0]))

                data.STATE = final_action

                prev_action = torch.from_numpy(
                    np.array(data.vec_to_ix([current_state,
                                             final_action])).reshape(
                                                 (1, 61))).float()

        #### under experiment ####
        goal_loss = L.action_loss_for_shortest_path(exp_logits,
                                                    data.goal,
                                                    current_state,
                                                    criterion,
                                                    log=True)
        goal_loss.backward(retain_graph=args.ret_graph)
        optimizer.step()
        print("Goal Loss: ", goal_loss.data[0])
        ####
        end_prob = time.time()
        prob_times.append(end_prob - start_prob)
    # print("solved {} out of {} -> {}".format(n_success, args.iters, n_success / args.iters))
    return DNC, optimizer, lstm_state, 0.  # running_avg(cum_correct, cum_total)
Example No. 12
def test_shortest_path_planning(args, data, DNC, lstm_state):
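    """
        Evaluate shortest-path planning on freshly generated graphs, recording
        the inputs and every DNC state tensor per step and dumping the traces
        to output.json; returns (action_score, goal_score).
        """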
    criterion = nn.CrossEntropyLoss().cuda() if args.cuda else nn.CrossEntropyLoss()
    cum_correct, cum_total, prob_times, n_success = [], [], [], 0
    action_score = 0
    goal_score = 0
    total_actions = 0
    total_problems = 2
    print("\n")
    print("Accuracy Test is started ....")
    output_dict = {}
    for i in range(total_problems):
        start_prob = time.time()
        phase_masks = data.make_new_graph()
        print("Shortest Path :: ", data.shortest_path)
        n_total, n_correct, prev_action, loss, stats = 0, 0, None, 0, []
        dnc_state = DNC.init_state(grad=False)
        lstm_state = DNC.init_rnn(grad=False)

        # traces of the inputs and every DNC state tensor, dumped to JSON below
        input_history, access_output_v = [], []
        m, rw, ww, l, lw, uuu = [], [], [], [], [], []

        def record(inp):
            # snapshot the current input row and each DNC state tensor
            input_history.append(inp.data[0].numpy().tolist())
            for trace, s in zip((access_output_v, m, rw, ww, l, lw, uuu),
                                dnc_state):
                trace.append(torch.squeeze(s.data).numpy().tolist())

        for phase_idx in phase_masks:

            if phase_idx == 0 or phase_idx == 1:
                inputs = _variable(data.getitem_combined())
                logits, dnc_state, lstm_state = DNC(inputs, lstm_state,
                                                    dnc_state)
                _, prev_action = data.strip_ix_mask(logits)
                record(inputs)

            elif phase_idx == 2:
                mask = _variable(data.getmask())
                inputs = torch.cat([mask, prev_action], 1)
                logits, dnc_state, lstm_state = DNC(inputs, lstm_state,
                                                    dnc_state)
                _, prev_action = data.strip_ix_mask(logits)
                record(inputs)

            else:
                best_nodes, all_nodes = data.get_actions()
                if not best_nodes:
                    break
                mask = data.getmask()
                prev_action = prev_action.cuda(
                ) if args.cuda is True else prev_action
                pr = u.depackage(prev_action)

                final_inputs = _variable(torch.cat([mask, pr], 1))
                logits, dnc_state, lstm_state = DNC(final_inputs, lstm_state,
                                                    dnc_state)
                record(final_inputs)
                exp_logits = data.ix_input_to_ixs(logits)
                current_state = data.STATE
                final_action, lstep = L.naive_loss_for_shortest_path(
                    exp_logits, all_nodes, current_state, criterion, log=True)
                print(
                    str(data.current_index) + " index, from: " +
                    str(current_state) + ", to: " + str(final_action))
                if final_action in best_nodes:
                    action_score = action_score + 1
                total_actions = total_actions + 1
                data.STATE = final_action
                prev_action = torch.from_numpy(
                    np.array(data.vec_to_ix([current_state,
                                             final_action])).reshape(
                                                 (1, 61))).float()
        if data.goal == final_action:
            goal_score = goal_score + 1
        end_prob = time.time()
        prob_times.append(end_prob - start_prob)
        p_output_dict = {
            "inputs": input_history,
            "access_outputs": access_output_v,
            "m": m,
            "rw": rw,
            "ww": ww,
            "l": l,
            "lw": lw,
            "u": uuu,
            "phases": torch.squeeze(phase_masks).numpy().tolist()
        }
        output_dict[i] = p_output_dict
    action_score = float(action_score) / float(total_actions)
    goal_score = float(goal_score) / float(total_problems)
    print("Accuracy Test is ended .... " + "action score: " +
          str(action_score) + ", goal score: " + str(goal_score))
    print("\n")
    with open("output.json", "w") as f:
        json.dump(output_dict,
                  f,
                  ensure_ascii=False,
                  indent=4,
                  sort_keys=True,
                  separators=(',', ': '))
    return action_score, goal_score
Example No. 13
def inference(examples):

  # CONV1
  with tf.variable_scope('conv1') as scope:
    # conv weights [filter_height, filter_width, filter_depth, num_filters]
    kernel = _variable('weights', [CONV1_HEIGHT, CONV1_WIDTH, 1, CONV1_FILTERS], tf.contrib.layers.xavier_initializer_conv2d())
    biases = _variable('biases', [CONV1_FILTERS], tf.constant_initializer(0.1))

    conv = conv2d(examples, kernel)
    conv1 = tf.nn.relu(conv + biases, name=scope.name)
    _activation_summary(conv1)

  # pool1 dim: [n, time, freq after pooling, num_filters]
  pool1 = tf.nn.max_pool(conv1, ksize=[1, POOL1_HEIGHT, POOL1_WIDTH, 1], 
    strides=[1, POOL1_STRIDE_HEIGHT, POOL1_STRIDE_WIDTH, 1], padding='SAME', name='pool1')

  # batch norm 1
  batch_norm1_object = BatchNorm(name='batch_norm1', shape=[CONV1_FILTERS])
  batch_norm1 = batch_norm1_object(pool1)

  # CONV2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable('weights', [CONV2_HEIGHT, CONV2_WIDTH, CONV1_FILTERS, CONV2_FILTERS], tf.contrib.layers.xavier_initializer_conv2d())
    biases = _variable('biases', [CONV2_FILTERS], tf.constant_initializer(0.1))

    conv = conv2d(batch_norm1, kernel)
    conv2 = tf.nn.relu(conv + biases, name=scope.name)
    _activation_summary(conv2)

  # POOL2
  pool2 = tf.nn.max_pool(conv2, ksize=[1, POOL2_HEIGHT, POOL2_WIDTH, 1], 
    strides=[1, POOL2_STRIDE_HEIGHT, POOL2_STRIDE_WIDTH, 1], padding='SAME', name='pool2')

  # batch norm 2
  batch_norm2_object = BatchNorm(name='batch_norm2', shape=[CONV2_FILTERS])
  batch_norm2 = batch_norm2_object(pool2)

  # FC3
  with tf.variable_scope('fc3') as scope:
    reshape = tf.reshape(batch_norm2, [BATCH_SIZE, -1])
    # integer division so the flattened dimension is an int (shape args must be ints)
    dim = (DIM_TIME // POOL1_HEIGHT // POOL2_HEIGHT) * (DIM_FREQ // POOL1_WIDTH // POOL2_WIDTH) * CONV2_FILTERS
    weights = _variable('weights', [dim, FC3_SIZE], tf.contrib.layers.xavier_initializer(), wd=config.fc_wd)
    biases = _variable('biases', [FC3_SIZE], tf.constant_initializer(0.1))

    fc3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
    _activation_summary(fc3)

  # FC4
  with tf.variable_scope('fc4') as scope:
    weights = _variable('weights', [FC3_SIZE, FC4_SIZE], tf.contrib.layers.xavier_initializer(), wd=config.fc_wd)
    biases = _variable('biases', [FC4_SIZE], tf.constant_initializer(0.1))

    fc4 = tf.nn.relu(tf.matmul(fc3, weights) + biases, name=scope.name)
    _activation_summary(fc4)

  # FC5
  with tf.variable_scope('fc5') as scope:
    weights = _variable('weights', [FC4_SIZE, FC5_SIZE], tf.contrib.layers.xavier_initializer(), wd=config.fc_wd)
    biases = _variable('biases', [FC5_SIZE], tf.constant_initializer(0.1))

    fc5 = tf.nn.relu(tf.matmul(fc4, weights) + biases, name=scope.name)
    _activation_summary(fc5)

  # FC6
  with tf.variable_scope('fc6') as scope:
    weights = _variable('weights', [FC5_SIZE, FC6_SIZE], tf.contrib.layers.xavier_initializer(), wd=config.fc_wd)
    biases = _variable('biases', [FC6_SIZE], tf.constant_initializer(0.1))

    fc6 = tf.nn.relu(tf.matmul(fc5, weights) + biases, name=scope.name)
    _activation_summary(fc6)

  # softmax
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable('weights', [FC6_SIZE, NUM_CLASSES], tf.contrib.layers.xavier_initializer())
    biases = _variable('biases', [NUM_CLASSES], tf.constant_initializer(0.0))
    # shape of y_conv is (N,3)
    softmax_linear = tf.add(tf.matmul(fc6, weights), biases, name=scope.name)
    _activation_summary(softmax_linear)
  return softmax_linear
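
A minimal TF1-style driver sketch (hypothetical input data; BATCH_SIZE, DIM_TIME, DIM_FREQ, and NUM_CLASSES are the module's constants): feed one batch of single-channel spectrogram patches through inference.

import numpy as np
import tensorflow as tf

examples = tf.placeholder(tf.float32, [BATCH_SIZE, DIM_TIME, DIM_FREQ, 1])
logits = inference(examples)  # shape (BATCH_SIZE, NUM_CLASSES)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  batch = np.zeros((BATCH_SIZE, DIM_TIME, DIM_FREQ, 1), dtype=np.float32)  # stand-in data
  out = sess.run(logits, feed_dict={examples: batch})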