Example #1
File: DQN.py Project: peiwangdb/JOS
    def optimize_model(self):
        samples = self.Memory.sample(64)
        if len(samples) == 0:
            return None
        # Queue every sample's value sub-graph on one fold so identical ops
        # across samples are batched into single kernel calls.
        fold = torchfold.Fold(cuda=True)
        fold_nodes = []
        nodes_per_sample = []
        for one_sample in samples:
            nodes = one_sample.env.selectValueFold(fold)
            nodes_per_sample.append(len(nodes))
            fold_nodes += nodes
        res = fold.apply(self.policy_net, [fold_nodes])[0]
        # Slice the batched result back into per-sample chunks.
        value_now_list = []
        next_value_list = []
        total = 0
        for idx, one_sample in enumerate(samples):
            value_now_list.append(self.policy_net.logits(
                res[total:total + nodes_per_sample[idx]],
                one_sample.env.sel.join_matrix))
            next_value_list.append(one_sample.next_value)
            total += nodes_per_sample[idx]
        value_now = torch.cat(value_now_list, dim=0)
        next_value = torch.cat(next_value_list, dim=0)
        # size_average is deprecated; reduction='mean' is the equivalent.
        loss = F.smooth_l1_loss(value_now, next_value, reduction='mean')
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
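All of these examples follow the same two-phase torchfold pattern: queue ops on a Fold with fold.add (which returns placeholder nodes and runs nothing), then batch-execute everything with fold.apply against a module whose method names match the op names. A minimal self-contained sketch of that pattern; the TreeEncoder module and its op names are illustrative, not taken from any of the projects shown here:

import torch
import torch.nn as nn
import torchfold

class TreeEncoder(nn.Module):
    def __init__(self, size=16):
        super().__init__()
        self.embed = nn.Embedding(100, size)
        self.combine = nn.Linear(2 * size, size)

    # Every method named in fold.add must accept batched inputs:
    # torchfold stacks all queued calls to the same op into one tensor.
    def leaf(self, ids):
        return self.embed(ids)

    def merge(self, left, right):
        return torch.relu(self.combine(torch.cat([left, right], dim=1)))

fold = torchfold.Fold()
a = fold.add('leaf', 1)            # placeholder node; nothing runs yet
b = fold.add('leaf', 2)
root = fold.add('merge', a, b)     # depends on a and b
encoder = TreeEncoder()
result = fold.apply(encoder, [[root]])   # one batched pass; [(1, 16) tensor]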
Example #2
    def forward(self, state, action):
        self.clear_buffer()
        if not self.disable_fold:
            self.fold = torchfold.Fold()
            self.fold.cuda()
            self.zeroFold_td = self.fold.add("zero_func_td")
            self.zeroFold_bu = self.fold.add("zero_func_bu")
            self.x1_fold, self.x2_fold = [], []
        assert state.shape[1] == self.state_dim * self.num_limbs, (
            'state.shape[1] expects {} but got {} with num_limbs being {} '
            'and state_dim being {}'.format(
                self.state_dim * self.num_limbs, state.shape[1],
                self.num_limbs, self.state_dim))
        for i in range(self.num_limbs):
            self.input_state[i] = state[:, i * self.state_dim:(i + 1) *
                                        self.state_dim]
            self.input_action[i] = action[:, i]
            self.input_action[i] = torch.unsqueeze(self.input_action[i], -1)
            if not self.disable_fold:
                self.input_state[i] = torch.unsqueeze(self.input_state[i], 0)
                self.input_action[i] = torch.unsqueeze(self.input_action[i], 0)

        if self.bu:
            # bottom up transmission by recursion
            for i in range(self.num_limbs):
                self.bottom_up_transmission(i)

        if self.td:
            # top down transmission by recursion
            for i in range(self.num_limbs):
                self.top_down_transmission(i)

        if not self.bu and not self.td:
            for i in range(self.num_limbs):
                if not self.disable_fold:
                    self.x1[i], self.x2[i] = self.fold.add(
                        'critic' + str(0).zfill(3), self.input_state[i],
                        self.input_action[i]).split(2)
                else:
                    self.x1[i], self.x2[i] = self.critic[i](
                        self.input_state[i], self.input_action[i])

        if not self.disable_fold:
            if self.bu and not self.td:
                self.x1_fold = self.x1_fold + [self.x1]
                self.x2_fold = self.x2_fold + [self.x2]
            else:
                self.x1_fold = self.x1_fold + self.x1
                self.x2_fold = self.x2_fold + self.x2
            self.x1, self.x2 = self.fold.apply(self,
                                               [self.x1_fold, self.x2_fold])
            self.x1 = torch.transpose(self.x1, 0, 1)
            self.x2 = torch.transpose(self.x2, 0, 1)
            self.fold = None
        else:
            self.x1 = torch.stack(self.x1, dim=-1)  # (bs,num_limbs,1)
            self.x2 = torch.stack(self.x2, dim=-1)

        return torch.sum(self.x1, dim=-1), torch.sum(self.x2, dim=-1)
Example #3
    def forward(self, state, mode="train"):
        self.clear_buffer()
        if mode == "inference":
            temp = self.batch_size
            self.batch_size = 1
        if not self.disable_fold:
            self.fold = torchfold.Fold()
            self.fold.cuda()
            self.zeroFold_td = self.fold.add("zero_func_td")
            self.zeroFold_bu = self.fold.add("zero_func_bu")
            self.a = []
        assert (
            state.shape[1] == self.state_dim * self.num_limbs
        ), "state.shape[1] expects {} but got {} with num_limbs being {} and state_dim being {}".format(
            self.state_dim * self.num_limbs,
            state.shape[1],
            self.num_limbs,
            self.state_dim,
        )

        for i in range(self.num_limbs):
            self.input_state[i] = state[:, i * self.state_dim:(i + 1) *
                                        self.state_dim]
            if not self.disable_fold:
                self.input_state[i] = torch.unsqueeze(self.input_state[i], 0)

        if self.bu:
            # bottom up transmission by recursion
            for i in range(self.num_limbs):
                self.bottom_up_transmission(i)

        if self.td:
            # top down transmission by recursion
            for i in range(self.num_limbs):
                self.top_down_transmission(i)

        if not self.bu and not self.td:
            for i in range(self.num_limbs):
                if not self.disable_fold:
                    self.action[i] = self.fold.add("actor" + str(0).zfill(3),
                                                   self.input_state[i])
                else:
                    self.action[i] = self.actor[i](self.input_state[i])

        if not self.disable_fold:
            self.a += self.action
            self.action = self.fold.apply(self, [self.a])[0]
            self.action = torch.transpose(self.action, 0, 1)
            self.fold = None
        else:
            self.action = torch.stack(self.action, dim=-1)
            self.msg_down = torch.stack(self.msg_down, dim=-1)

        if mode == "inference":
            self.batch_size = temp

        return torch.squeeze(self.action)
Example #4
    def test_rnn(self):
        f = torchfold.Fold()
        v1, _ = f.add('value2', 1).split(2)
        v2, _ = f.add('value2', 2).split(2)
        r = v1
        for i in range(1000):
            r = f.add('attr', v1, v2)
            r = f.add('attr', r, v2)

        te = TestEncoder()
        enc = f.apply(te, [[r]])
        self.assertEqual(enc[0].size(), (1, 10))
Example #5
    def Q1(self, state, action):
        self.clear_buffer()
        if not self.disable_fold:
            self.fold = torchfold.Fold()
            self.fold.cuda()
            self.zeroFold_td = self.fold.add("zero_func_td")
            self.zeroFold_bu = self.fold.add("zero_func_bu")
            self.x1_fold = []

        for i in range(self.num_limbs):
            self.input_state[i] = state[:, i * self.state_dim:(i + 1) *
                                        self.state_dim]
            self.input_action[i] = action[:, i]
            self.input_action[i] = torch.unsqueeze(self.input_action[i], -1)
            if not self.disable_fold:
                self.input_state[i] = torch.unsqueeze(self.input_state[i], 0)
                self.input_action[i] = torch.unsqueeze(self.input_action[i], 0)

        if self.bu:
            # bottom up transmission by recursion
            for i in range(self.num_limbs):
                self.bottom_up_transmission(i)

        if self.td:
            # top down transmission by recursion
            for i in range(self.num_limbs):
                self.top_down_transmission(i)

        if not self.bu and not self.td:
            for i in range(self.num_limbs):
                if not self.disable_fold:
                    self.x1[i] = self.fold.add(
                        "critic" + str(0).zfill(3),
                        self.input_state[i],
                        self.input_action[i],
                    )
                else:
                    self.x1[i] = self.critic[i](self.input_state[i],
                                                self.input_action[i])

        if not self.disable_fold:
            if self.bu and not self.td:
                self.x1 = [self.x1]
            self.x1_fold = self.x1_fold + self.x1
            self.x1 = self.fold.apply(self, [self.x1_fold])[0]
            if not self.bu and not self.td:
                self.x1 = self.x1[0]
            self.x1 = torch.transpose(self.x1, 0, 1)
            self.fold = None
        else:
            self.x1 = torch.stack(self.x1, dim=-1)  # (bs,num_limbs,1)

        return torch.sum(self.x1, dim=-1)
Example #6
    def test_nobatch(self):
        f = torchfold.Fold()
        v = []
        for i in range(15):
            v.append(f.add('value', i % 10))
        d = f.add('concat', *v).nobatch()
        res = []
        for i in range(100):
            res.append(f.add('logits', v[i % 10], d))

        te = TestEncoder()
        enc = f.apply(te, [res])
        self.assertEqual(len(enc), 1)
        self.assertEqual(enc[0].size(), (100, 15))
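For context on the .nobatch() call above: it marks a node as shared rather than batched, so fold.apply passes its value once, as-is, to the batched op call instead of stacking one copy per use (and every call batched together must reference that same node). A small sketch of the idea; the ScorerNet module and its ops are hypothetical:

import torch
import torch.nn as nn
import torchfold

class ScorerNet(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.embed = nn.Embedding(10, dim)

    def value(self, ids):            # batched: one row per queued call
        return self.embed(ids)

    def table(self, *rows):          # builds the shared lookup table once
        return torch.cat(rows, 0)    # (num_rows, dim)

    def logits(self, queries, table):
        # queries: (batch, dim); table arrives unbatched as (num_rows, dim)
        return torch.mm(queries, table.t())

f = torchfold.Fold()
rows = [f.add('value', i) for i in range(5)]
shared = f.add('table', *rows).nobatch()   # shared across all logits calls
scores = [f.add('logits', rows[i % 5], shared) for i in range(20)]
out = f.apply(ScorerNet(), [scores])
print(out[0].size())                       # torch.Size([20, 5])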
Example #7
def main():
    inputs = datasets.snli.ParsedTextField(lower=True)
    transitions = datasets.snli.ShiftReduceField()
    answers = data.Field(sequential=False)

    train, dev, test = datasets.SNLI.splits(inputs, answers, transitions)
    inputs.build_vocab(train, dev, test)
    answers.build_vocab(train)
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train, dev, test),
        batch_size=args.batch_size,
        device=0 if args.cuda else -1)

    model = SPINN(3, 500, 1000)
    criterion = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=0.01)

    for epoch in range(10):
        start = time.time()
        iteration = 0
        for batch_idx, batch in enumerate(train_iter):
            opt.zero_grad()

            all_logits, all_labels = [], []
            fold = torchfold.Fold(cuda=args.cuda)
            for example in batch.dataset:
                tree = Tree(example, inputs.vocab, answers.vocab)
                if args.fold:
                    all_logits.append(encode_tree_fold(fold, tree))
                else:
                    all_logits.append(encode_tree_regular(model, tree))
                all_labels.append(tree.label)

            if args.fold:
                res = fold.apply(model, [all_logits, all_labels])
                loss = criterion(res[0], res[1])
            else:
                loss = criterion(torch.cat(all_logits, 0),
                                 torch.tensor(all_labels, dtype=torch.long))
            loss.backward()
            opt.step()

            iteration += 1
            if iteration % 10 == 0:
                print("Avg. Time: %fs" % ((time.time() - start) / iteration))
Example #8
    def test_rnn_optimized_chunking(self):
        seq_lengths = [2, 3, 5]

        states = []
        for _ in range(len(seq_lengths)):
            states.append(self._generate_variable(self.num_units))

        f = torchfold.Fold()
        for seq_ind in range(len(seq_lengths)):
            for _ in range(seq_lengths[seq_ind]):
                states[seq_ind] = f.add(
                    'encode', self._generate_variable(self.input_size),
                    states[seq_ind])

        enc = RNNEncoder(self.num_units, self.input_size)
        with mock.patch.object(torch, 'chunk',
                               wraps=torch.chunk) as wrapped_chunk:
            result = f.apply(enc, [states])
            # torch.chunk is called 3 times instead of max(seq_lengths)=5.
            self.assertEqual(3, wrapped_chunk.call_count)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].size(), (len(seq_lengths), self.num_units))
Example #9
    def compute_loss_annotated(self, batch):
        batch_size = batch.batch_size
        if self.args.cuda:
            batch = batch.cuda_train()

        initial_state, memory, initial_logits = self.model.prepare_initial(
            batch.input_grids, batch.output_grids, batch.current_grids,
            batch.current_code)
        max_code_length = memory.current_code.memory.shape[2]
        initial_state_orig = initial_state
        # state before: (batch x num pairs x hidden,
        #                num layers x batch x num pairs x hidden,
        #                num layers x batch x num pairs x hidden)
        # state after: list of (1 x num pairs x hidden,
        #                       1 x num layers x num pairs x hidden,
        #                       1 x num layers x num pairs x hidden)
        # list() so it can be indexed per batch item below (zip is lazy in
        # Python 3 and cannot be subscripted).
        initial_state = list(
            zip(
                torch.chunk(initial_state.context, batch_size),
                torch.chunk(initial_state.h.permute(1, 0, 2, 3), batch_size),
                torch.chunk(initial_state.c.permute(1, 0, 2, 3), batch_size)))
        # memory: io, current_grid, current_code.memory, current_code.attn_mask
        # before: (batch x num pairs x 512,
        #          batch x num pairs x 256,
        #          batch x num pairs x max code length x 512,
        #          batch x num pairs x max code length)
        # after: list of (1 x num pairs x 512,
        #                 1 x num pairs x 256,
        #                 1 x num pairs x max code length x 512,
        #                 1 x num pairs x max code length)
        memory = list(
            zip(*(torch.chunk(t, batch_size) for t in memory.to_flat())))
        initial_logits = torch.chunk(initial_logits, batch_size)

        fold = torchfold.Fold(cuda=self.args.cuda)
        #fold = torchfold.Unfold(nn=self.model, cuda=self.args.cuda)
        zero = fold.add('tf_torch_zero')
        log_probs = []
        for batch_idx, allowed_edits in enumerate(batch.allowed_edits):
            item_log_probs = []
            item_memory = memory[batch_idx]

            # before: 1 x num pairs x length x hidden size
            # after: 1 x length x hidden size
            current_code_memory = item_memory[2][:, 0]
            current_code_attn_mask = item_memory[3][:, 0]

            def step(state, choice_name):
                output, context, h, c = fold.add(
                    'tf_step',
                    fold.add('choice_emb',
                             self.model.choice_vocab.stoi(choice_name)),
                    *(state + item_memory)).split(4)
                return output, (context, h, c)

            def step_pointer(state, loc):
                output, context, h, c = fold.add(
                    'tf_step',
                    fold.add('tf_get_code_emb', current_code_memory, loc),
                    *(state + item_memory)).split(4)
                return output, (context, h, c)

            def log_prob(logits, idx, size):
                assert idx < size
                return fold.add(
                    'tf_get_log_prob:{}'.format(size),
                    fold.add('tf_torch_log_softmax:{}'.format(size), logits),
                    idx)

            def pointer_logits(output, loc):
                assert current_code_attn_mask[0, loc] == 0
                return fold.add('pointer_logits', output, current_code_memory,
                                current_code_attn_mask)

            def batched_sum(v1, v2, v3=zero, v4=zero, v5=zero):
                # v* shape: batch x 1
                return fold.add('tf_batched_sum', v1, v2, v3, v4, v5)

            for action_type, action_args in allowed_edits:
                if action_type == mutation.ADD_ACTION:
                    location, karel_action = action_args

                    action_log_prob = log_prob(
                        initial_logits[batch_idx],
                        self.model.initial_vocab.stoi(karel_action),
                        len(self.model.initial_vocab))

                    output, state = step(initial_state[batch_idx],
                                         karel_action)
                    loc_log_prob = log_prob(pointer_logits(output, location),
                                            location, max_code_length)

                    item_log_probs.append(
                        batched_sum(action_log_prob, loc_log_prob))

                elif action_type == mutation.WRAP_BLOCK:
                    block_type, cond_id, start, end = action_args

                    block_type_log_prob = log_prob(
                        initial_logits[batch_idx],
                        self.model.initial_vocab.stoi(block_type),
                        len(self.model.initial_vocab))
                    output, state = step(initial_state[batch_idx], block_type)

                    if block_type == 'repeat':
                        cond_log_prob = log_prob(
                            fold.add('repeat_logits', output), cond_id,
                            len(mutation.REPEAT_COUNTS))
                        cond = len(mutation.CONDS) + cond_id
                    else:
                        cond_log_prob = log_prob(
                            fold.add('cond_logits', output), cond_id,
                            len(mutation.CONDS))
                        cond = cond_id
                    output, state = step(state, cond)

                    start_log_prob = log_prob(pointer_logits(output, start),
                                              start, max_code_length)
                    output, state = step_pointer(state, start)

                    end_log_prob = log_prob(pointer_logits(output, end), end,
                                            max_code_length)

                    item_log_probs.append(
                        batched_sum(block_type_log_prob, cond_log_prob,
                                    start_log_prob, end_log_prob))

                elif action_type == mutation.WRAP_IFELSE:
                    cond_id, if_start, else_start, end = action_args

                    block_type_log_prob = log_prob(
                        initial_logits[batch_idx],
                        self.model.initial_vocab.stoi('ifElse'),
                        len(self.model.initial_vocab))
                    output, state = step(initial_state[batch_idx], 'ifElse')

                    cond_log_prob = log_prob(fold.add('cond_logits', output),
                                             cond_id, len(mutation.CONDS))
                    output, state = step(state, cond_id)

                    if_start_log_prob = log_prob(
                        pointer_logits(output, if_start), if_start,
                        max_code_length)
                    output, state = step_pointer(state, if_start)

                    else_start_log_prob = log_prob(
                        pointer_logits(output, else_start), else_start,
                        max_code_length)
                    output, state = step_pointer(state, else_start)

                    end_log_prob = log_prob(pointer_logits(output, end), end,
                                            max_code_length)

                    item_log_probs.append(
                        batched_sum(block_type_log_prob, cond_log_prob,
                                    if_start_log_prob, else_start_log_prob,
                                    end_log_prob))

            if not allowed_edits:
                item_log_probs.append(
                    log_prob(initial_logits[batch_idx],
                             len(self.model.initial_vocab) - 1,
                             len(self.model.initial_vocab)))

            log_probs.append(item_log_probs)

        # log_probs before: list (batch size) of list (allowed_edits)
        # log_probs after: list (batch size) of Tensor, each with length
        #                  `allowed_edits`
        log_probs_t = fold.apply(self.model, log_probs)
        log_probs_per_example = [utils.logsumexp(t) for t in log_probs_t]
        loss = -torch.mean(torch.cat(log_probs_per_example))

        return loss, log_probs_per_example
Example #10
def main():
    device_type = 'cuda' if args.cuda else 'cpu'
    device = torch.device(device_type)

    print("Running on: {}".format(device))

    #####################################
    ## configure experiment parameters ##
    #####################################
    batch_sizes = [1, 32, 64, 128, 256, 512, 1024]
    epochs = 1
    learning_rate = 0.001
    max_samples = 5000  # number of samples to use for experiment
    #####################################

    inputs = ParsedTextField(lower=True)
    transitions = ShiftReduceField()
    labels = data.Field(sequential=False)

    print("Loading dataset...")
    train, dev, test = datasets.SNLI.splits(inputs, labels, transitions)
    inputs.build_vocab(train, dev, test)
    labels.build_vocab(train)
    print("Done.")
    for batch_size in batch_sizes:
        print("Batching dataset into mini-batches of size {}..".format(
            batch_size))
        train_iter, _, _ = data.BucketIterator.splits((train, dev, test),
                                                      batch_size=batch_size,
                                                      device=device)
        print("Done.")

        print("Configuring SPINN model...")
        model = SPINN(3, 500, len(inputs.vocab))
        if args.cuda:
            model.to(device)
        criterion = nn.CrossEntropyLoss()
        opt = optim.Adam(model.parameters(), lr=learning_rate)
        print("Done.")

        for epoch in range(epochs):
            print("starting epoch {}".format(epoch))

            all_batch_times = []
            for batch_idx, batch in enumerate(train_iter):
                opt.zero_grad()  # reset gradients per mini-batch
                all_logits, all_labels = [], []

                if args.dynamic:
                    fold = torchfold.Fold()
                    if args.cuda:
                        fold.cuda()

                start = timer()
                tree_sizes = []
                # because batch.dataset starts at the beginning of the entire
                # dataset instead of where the previous batch left off
                for sample_idx in range(batch_idx * batch_size,
                                        (batch_idx + 1) * batch_size):
                    # HACK this is to account for the final batch which may or may not be
                    # of size batch_size - there should be a more elegant solution to this
                    if sample_idx == len(batch.dataset) - 1:
                        break

                    tree = Tree(batch.dataset[sample_idx].label,
                                batch.dataset[sample_idx].premise_transitions,
                                batch.dataset[sample_idx].premise,
                                inputs.vocab, labels.vocab)
                    if args.dynamic:
                        all_logits.append(encode_tree_fold(fold, tree))
                    else:
                        all_logits.append(encode_tree_regular(model, tree))
                    all_labels.append(tree.label)

                if args.dynamic:
                    res = fold.apply(model, [all_logits, all_labels])
                    batch_time = timer() - start
                    loss = criterion(res[0], res[1])
                else:
                    test = np.asarray(all_labels, dtype=int)
                    x = torch.from_numpy(test).to(device)
                    batch_time = timer() - start
                    loss = criterion(torch.cat(all_logits, 0), x)

                loss.backward()
                opt.step()

                ####################
                ## Gather results ##
                ####################
                all_batch_times.append(batch_time)
                results['time'].append(batch_time)
                results['epoch'].append(epoch)
                results['batch'].append(batch_idx)
                results['sample'].append(sample_idx)
                results['batch_size'].append(batch_size)
                ts = tree_size(tree.root)
                tree_sizes.append(ts)
                results['treesize'].append(np.mean(tree_sizes))
                ####################

                if batch_idx % 10 == 1:
                    print(
                        "batch size: {} sample: {}/{} loss: {:.4f} - Avg. Time (per batch): {:.5f}s"
                        .format(batch_size, batch_idx * batch_size,
                                max_samples, loss, np.mean(all_batch_times)))
                # only need to look at first 5000 samples for each batch
                if batch_idx * batch_size > max_samples:
                    break

            print("done epoch {}".format(epoch))

        with open(
                os.path.join(
                    ROOT, "results_fold-{}-{}-{}-backup.json".format(
                        args.dynamic, batch_size, args.cuda)), "w+") as fd:
            json.dump(results, fd)

    with open(
            os.path.join(
                ROOT,
                "results_fold-{}-{}-full.json".format(args.dynamic,
                                                      args.cuda)), "w+") as fd:
        json.dump(results, fd)
Example #11
grassData = GRASS('data')
dataloader = torch.utils.data.DataLoader(grassData,
                                         batch_size=123,
                                         shuffle=True,
                                         collate_fn=class_collate)

optimizer_encoder = torch.optim.Adam(encoder.parameters(), lr=1e-3)
optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=1e-3)

for epoch in range(500):
    if epoch % 100 == 0 and epoch != 0:
        torch.save(encoder, 'VAEencoder.pkl')
        torch.save(decoder, 'VAEdecoder.pkl')
    for i, batch in enumerate(dataloader):
        fold = torchfold.Fold(cuda=True, variable=False)
        encoding = []
        for example in batch:
            encoding.append(model.encode_structure_fold(fold, example))
        encoding = fold.apply(encoder, [encoding])
        encoding = torch.split(encoding[0], 1, 0)
        decodingLoss = []
        fold = torchfold.Fold(cuda=True, variable=True)
        kldLoss = []
        for example, f in zip(batch, encoding):
            ff, kld = torch.chunk(f, 2, 1)
            decodingLoss.append(model.decode_structure_fold(fold, ff, example))
            kldLoss.append(kld)
        decodingLoss = fold.apply(decoder, [decodingLoss, kldLoss])
        err_re = decodingLoss[0].sum() / len(batch)
        err_kld = decodingLoss[1].sum().mul(-0.05) / len(batch)
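The excerpt above stops after the reconstruction and KL terms are computed; the backward pass is cut off. A plausible continuation, assuming the two optimizers defined at the top of the example (a sketch, not the project's verbatim code):

        # Hypothetical continuation: combine both loss terms and step
        # the encoder and decoder optimizers.
        total_loss = err_re + err_kld
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        total_loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()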
Example #12
    def _training_pass(self, valid_rooms, epoch, is_training=True):
        """
        Single training pass
        :param valid_rooms: choice of =[self.valid_rooms_train, self.valid_rooms_test]
        :param epoch: current epoch
        :param is_training: train or test pass
        :return:
        """

        ''' epoch and args '''
        epoch += self.pretrained_epoch
        opt_parser = self.opt_parser

        ''' current training state '''
        if (is_training):
            self.STATE = 'TRAIN'
            self.full_enc.train()
            self.full_dec.train()
        else:
            self.STATE = 'EVAL'
            self.full_enc.eval()
            self.full_dec.eval()

        ''' init loss / accuracy '''
        loss_cat_per_epoch, loss_dim_per_epoch = 0.0, 0.0
        num_node_per_epoch, dim_acc_per_epoch = 0.0, 0.0
        acc_cat_per_epoch = {1: 0.0, 3: 0.0, 5: 0.0}

        ''' shuffle room list and create training batches '''
        shuffle(valid_rooms)
        room_indices = list(range(len(valid_rooms)))
        room_idx_batches = [room_indices[i: i + opt_parser.batch_size] for i in
                            range(0, len(valid_rooms), opt_parser.batch_size)]

        ''' Batch loop '''
        for batch_i, batch in enumerate(room_idx_batches):

            batch_rooms = [valid_rooms[i] for i in batch]

            """ ==================================================================
                                        Encoder Part
            ================================================================== """
            # init torchfold
            enc_fold = torchfold.Fold()
            enc_fold_nodes = []
            enc_rand_path_order = []
            enc_rand_path_root_to_leaf_order = []

            # loop for rooms
            for room_i, room in enumerate(batch_rooms):

                node_list = self.__preprocess_root_wall_nodes__(room['node_list'])

                # adapt acceleration for large graphs (by splitting into sub-graphs)
                consider_path_type = ['root']
                root_to_split = False

                if(opt_parser.adapt_training_on_large_graph):
                    if (len(node_list.keys()) >= int(opt_parser.max_scene_nodes)):
                        consider_path_type = node_list['root']['support']
                        root_to_split = True

                # loop for sub-graphs
                for sub_tree_root_node in consider_path_type:

                    # find sub-graph's root to leaf node path
                    subtree_to_leaf_path = self.find_root_to_leaf_node_path(node_list, cur_node=sub_tree_root_node)

                    # skip unreasonable paths
                    subtree_to_leaf_path = [p for p in subtree_to_leaf_path if len(p) >= 2 and len(p) < opt_parser.max_scene_nodes]
                    subtree_to_leaf_path = [p for p in subtree_to_leaf_path if 'wall' not in p[-1].split('_')[0]]
                    if(len(subtree_to_leaf_path) == 0):
                        continue

                    # find node list for sub-graphs
                    sub_keys = list(set(self.find_selected_node_list(node_list, sub_tree_root_node)))
                    if(root_to_split):
                        sub_keys += ['root']
                    sub_node_list = dict((k, node_list[k]) for k in sub_keys if k in node_list.keys())

                    # update parents, childs, siblings for each node
                    sub_node_list = self.find_parent_sibling_child_list(sub_node_list)

                    # exclude examples with too many sub tree nodes
                    if(len(sub_node_list.keys()) >= int(opt_parser.max_scene_nodes)):
                        print('skip too large sub-scene:', len(sub_node_list.keys()), '>', opt_parser.max_scene_nodes)
                        continue

                    subtree_to_leaf_path.sort()
                    # loop for each root-to-leaf path
                    for rand_path_idx, rand_path in enumerate(subtree_to_leaf_path):

                        rand_path_fold, rand_path_node_name_order = self.model.encode_tree_fold(enc_fold, sub_node_list, rand_path, opt_parser)
                        enc_fold_nodes += rand_path_fold
                        enc_rand_path_order += [[room_i, sub_tree_root_node] + rand_path_node_name_order]
                        enc_rand_path_root_to_leaf_order += [rand_path]

            # if batch size is too small, sometimes there is no valid training instance.
            if(len(enc_fold_nodes) == 0):
                print('surprisingly this batch has no valid training trees!')
                continue


            # torch-fold train encoder
            enc_fold_nodes = enc_fold.apply(self.full_enc, [enc_fold_nodes])
            enc_fold_nodes = torch.split(enc_fold_nodes[0], 1, 0)

            """ ==================================================================
                                        Decoder Part
            ================================================================== """
            # Ground-truth leaf node vec
            leaf_node_gt = []

            # FOLD
            dec_fold = torchfold.Fold()
            dec_fold_nodes = []

            # loop for all encoded vectors
            for i, rand_path_order in enumerate(enc_rand_path_order):

                # find room-node-list
                room_i = rand_path_order[0]
                node_list = batch_rooms[room_i]['node_list']

                # decode to k-vec and add Ops to fold
                dec_fold_nodes.append(self.model.decode_tree_fold(dec_fold, enc_fold_nodes[i], opt_parser))
                leaf_node_gt += [self.model.get_gt_k_vec(node_list, enc_rand_path_root_to_leaf_order[i][-1], opt_parser)]  # leaf node ground-truth k-vec

            # torch-fold decoder
            dec_fold_nodes = dec_fold.apply(self.full_dec, [dec_fold_nodes])
            leaf_node_pred = dec_fold_nodes[0]

            """ ==================================================================
                                      Loss / Accuracy Part
            ================================================================== """
            size_pos_dim = 6

            leaf_node_cat_gt = [c[:-size_pos_dim].index(1) for c in leaf_node_gt]
            leaf_node_cat_gt = to_torch(leaf_node_cat_gt, torch_type=torch.LongTensor, dim_0=len(leaf_node_gt)).view(-1)

            leaf_node_dim_gt = [c[-size_pos_dim:-size_pos_dim+3] for c in leaf_node_gt]
            leaf_node_dim_gt = to_torch(leaf_node_dim_gt, torch_type=torch.FloatTensor, dim_0=len(leaf_node_gt))

            loss_cat = self.LOSS_CLS(leaf_node_pred[:, :-size_pos_dim], leaf_node_cat_gt)
            loss_dim = self.LOSS_L2(leaf_node_pred[:, -size_pos_dim:-size_pos_dim+3], leaf_node_dim_gt) * 1000

            # report scores
            loss_cat_per_batch = loss_cat.data.cpu().numpy()
            loss_dim_per_batch = loss_dim.data.cpu().numpy()
            num_node_per_batch = len(leaf_node_gt) * 1.0

            # accuracy (top k)
            acc_cat_per_batch = {}
            for k in [1, 3, 5]:
                _, pred = leaf_node_pred[:, :-size_pos_dim].topk(k, 1, True, True)
                pred = pred.t()
                correct = pred.eq(leaf_node_cat_gt.view(1, -1).expand_as(pred))
                # reshape instead of view: `correct` comes from a transposed
                # (non-contiguous) tensor, on which .view(-1) raises.
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                acc_cat_per_batch[k] = correct_k[0].cpu().numpy()

            # dimension (diagonal) percentage off
            diag_pred = np.sqrt(
                np.sum(leaf_node_pred[:, -size_pos_dim:-size_pos_dim+3].data.cpu().numpy() ** 2, axis=1))
            diag_gt = np.sqrt(
                np.sum(leaf_node_dim_gt.data.cpu().numpy() ** 2, axis=1))
            dim_acc_per_batch = np.sum(np.abs(diag_pred - diag_gt) / diag_gt)

            loss_cat_per_epoch += loss_cat_per_batch
            loss_dim_per_epoch += loss_dim_per_batch
            num_node_per_epoch += num_node_per_batch
            dim_acc_per_epoch += dim_acc_per_batch
            for key in acc_cat_per_epoch.keys():
                acc_cat_per_epoch[key] += acc_cat_per_batch[key]

            if (is_training):

                # Back-propagation
                for key in self.opt.keys():
                    self.opt[key].zero_grad()

                # only train object dimensions
                if(opt_parser.train_dim and not opt_parser.train_cat):
                    loss_dim.backward()
                # only train object categories
                elif(opt_parser.train_cat and not opt_parser.train_dim):
                    loss_cat.backward()
                # train both
                elif(opt_parser.train_cat and opt_parser.train_dim):
                    loss_cat.backward(retain_graph=True)
                    loss_dim.backward()
                else:
                    print('At least enable --train_cat or --train_dim.')
                    exit(-1)

                for key in self.opt.keys():
                    self.opt[key].step()

            if (opt_parser.verbose >= 0):
                print(self.STATE, opt_parser.name, epoch,
                      ': ({}/{}:{})'.format(batch_i, len(room_idx_batches), num_node_per_batch),
                      'CAT Loss: {:.4f}, Acc_1: {:.4f}, Acc_3: {:.4f}, Acc_5: {:.4f}, Dim Loss: {:.8f}, dim acc: {:.2f}'.format(
                          loss_cat_per_batch / num_node_per_batch * 100.0,
                          acc_cat_per_batch[1] / num_node_per_batch,
                          acc_cat_per_batch[3] / num_node_per_batch,
                          acc_cat_per_batch[5] / num_node_per_batch,
                          loss_dim_per_batch / num_node_per_batch,
                          dim_acc_per_batch / num_node_per_batch))

        """ ==================================================================
                                  Report Part
        ================================================================== """

        print('========================================================')
        print(self.STATE, epoch, ': ',
              'CAT Loss: {:.4f}, Acc_1: {:.4f}, Acc_3: {:.4f}, Acc_5: {:.4f}, Dim Loss: {:.4f}, Dim acc: {:.4f}'.format(
                  loss_cat_per_epoch / num_node_per_epoch * 100.0,
                  acc_cat_per_epoch[1] / num_node_per_epoch,
                  acc_cat_per_epoch[3] / num_node_per_epoch,
                  acc_cat_per_epoch[5] / num_node_per_epoch,
                  loss_dim_per_epoch / num_node_per_epoch,
                  dim_acc_per_epoch / num_node_per_epoch))
        print('========================================================')

        ''' write avg to log '''
        if (opt_parser.write):
            self.writer.add_scalar('{}_LOSS_CAT'.format(self.STATE),
                                   loss_cat_per_epoch / num_node_per_epoch, epoch)
            self.writer.add_scalar('{}_ACC_CAT'.format(self.STATE),
                                   acc_cat_per_epoch[1] / num_node_per_epoch, epoch)
            self.writer.add_scalar('{}_ACC_3_CAT'.format(self.STATE),
                                   acc_cat_per_epoch[3] / num_node_per_epoch, epoch)
            self.writer.add_scalar('{}_ACC_5_CAT'.format(self.STATE),
                                   acc_cat_per_epoch[5] / num_node_per_epoch, epoch)
            self.writer.add_scalar('{}_LOSS_DIM'.format(self.STATE),
                                   loss_dim_per_epoch / num_node_per_epoch, epoch)

        ''' save model '''
        if (not is_training):
            def save_model(save_type):
                torch.save({
                    'full_enc_state_dict': self.full_enc.state_dict(),
                    'full_dec_state_dict': self.full_dec.state_dict(),
                    'full_enc_opt': self.opt['full_enc'].state_dict(),
                    'full_dec_opt': self.opt['full_dec'].state_dict(),
                    'epoch': epoch
                }, '{}/Entire_model_{}.pth'.format(opt_parser.outf, save_type))

            # if model is better, save model checkpoint
            # min dim loss model
            if(loss_dim_per_epoch / num_node_per_epoch < self.MIN_DIM_LOSS):
                self.MIN_DIM_LOSS = loss_dim_per_epoch / num_node_per_epoch
                save_model('min_dim_loss')
            # max cat acc model (top-5 acc)
            if (acc_cat_per_epoch[5] / num_node_per_epoch > self.MAX_ACC):
                self.MAX_ACC = acc_cat_per_epoch[5] / num_node_per_epoch
                save_model('max_acc')
            # min cat loss model
            if (loss_cat_per_epoch / num_node_per_epoch < self.MIN_LOSS):
                self.MIN_LOSS = loss_cat_per_epoch / num_node_per_epoch
                save_model('min_loss')
            # always save the latest model
            save_model('last_epoch')

        return