def optimize_model(self):
    samples = self.Memory.sample(64)
    if len(samples) == 0:
        return None
    # Record every per-sample value computation into a single Fold graph.
    fold = torchfold.Fold(cuda=True)
    nowL = []
    value_now_list = []
    for one_sample in samples:
        nowList = one_sample.env.selectValueFold(fold)
        nowL.append(len(nowList))
        value_now_list += nowList
    res = fold.apply(self.policy_net, [value_now_list])[0]

    # Slice the batched result back into per-sample chunks.
    total = 0
    value_now_list = []
    next_value_list = []
    for idx, one_sample in enumerate(samples):
        value_now_list.append(self.policy_net.logits(
            res[total:total + nowL[idx]], one_sample.env.sel.join_matrix))
        next_value_list.append(one_sample.next_value)
        total += nowL[idx]
    value_now = torch.cat(value_now_list, dim=0)
    next_value = torch.cat(next_value_list, dim=0)

    # `size_average` is deprecated; reduction='mean' is the equivalent.
    loss = F.smooth_l1_loss(value_now, next_value, reduction='mean')
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()
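# The method above relies on torchfold's deferred-execution pattern: ops are
# recorded with fold.add(...) and executed in one batched pass by fold.apply.
# A minimal, self-contained sketch of that pattern; ToyNet and its `score` op
# are illustrative names, not part of the code above.
import torch
import torch.nn as nn
import torchfold

class ToyNet(nn.Module):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.embed = nn.Embedding(10, 8)
        self.fc = nn.Linear(8, 8)

    def score(self, idx):
        # Invoked once on the whole batch of recorded 'score' ops.
        return self.fc(self.embed(idx))

fold = torchfold.Fold()  # Fold(cuda=True) to batch on GPU, as above
nodes = [fold.add('score', i % 10) for i in range(64)]
result = fold.apply(ToyNet(), [nodes])[0]  # one batched forward, shape (64, 8)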
def forward(self, state, action):
    self.clear_buffer()
    if not self.disable_fold:
        self.fold = torchfold.Fold()
        self.fold.cuda()
        self.zeroFold_td = self.fold.add("zero_func_td")
        self.zeroFold_bu = self.fold.add("zero_func_bu")
        self.x1_fold, self.x2_fold = [], []
    assert state.shape[1] == self.state_dim * self.num_limbs, \
        'state.shape[1] expects {} but got {} with num_limbs being {} and state_dim being {}'.format(
            self.state_dim * self.num_limbs, state.shape[1],
            self.num_limbs, self.state_dim)

    for i in range(self.num_limbs):
        self.input_state[i] = state[:, i * self.state_dim:(i + 1) * self.state_dim]
        self.input_action[i] = action[:, i]
        self.input_action[i] = torch.unsqueeze(self.input_action[i], -1)
        if not self.disable_fold:
            self.input_state[i] = torch.unsqueeze(self.input_state[i], 0)
            self.input_action[i] = torch.unsqueeze(self.input_action[i], 0)

    if self.bu:
        # bottom up transmission by recursion
        for i in range(self.num_limbs):
            self.bottom_up_transmission(i)

    if self.td:
        # top down transmission by recursion
        for i in range(self.num_limbs):
            self.top_down_transmission(i)

    if not self.bu and not self.td:
        for i in range(self.num_limbs):
            if not self.disable_fold:
                self.x1[i], self.x2[i] = self.fold.add(
                    'critic' + str(0).zfill(3), self.input_state[i],
                    self.input_action[i]).split(2)
            else:
                self.x1[i], self.x2[i] = self.critic[i](
                    self.input_state[i], self.input_action[i])

    if not self.disable_fold:
        if self.bu and not self.td:
            self.x1_fold = self.x1_fold + [self.x1]
            self.x2_fold = self.x2_fold + [self.x2]
        else:
            self.x1_fold = self.x1_fold + self.x1
            self.x2_fold = self.x2_fold + self.x2
        self.x1, self.x2 = self.fold.apply(self, [self.x1_fold, self.x2_fold])
        self.x1 = torch.transpose(self.x1, 0, 1)
        self.x2 = torch.transpose(self.x2, 0, 1)
        self.fold = None
    else:
        self.x1 = torch.stack(self.x1, dim=-1)  # (bs, num_limbs, 1)
        self.x2 = torch.stack(self.x2, dim=-1)

    return torch.sum(self.x1, dim=-1), torch.sum(self.x2, dim=-1)
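# In the fold branch above, fold.apply(self, ...) dispatches each recorded op
# name back to a method on the module itself, so a method named 'critic000'
# returning the two Q-values must exist. A minimal sketch of that naming
# contract; MiniCritic and its dimensions are illustrative assumptions, not
# the real class.
import torch
import torch.nn as nn
import torchfold

class MiniCritic(nn.Module):
    def __init__(self, state_dim=4, action_dim=1):
        super(MiniCritic, self).__init__()
        self.q1 = nn.Linear(state_dim + action_dim, 1)
        self.q2 = nn.Linear(state_dim + action_dim, 1)

    def critic000(self, state, action):
        # Receives all ops recorded under this name, batched along dim 0.
        sa = torch.cat([state, action], dim=-1)
        return self.q1(sa), self.q2(sa)

net = MiniCritic()
fold = torchfold.Fold()
x1, x2 = fold.add('critic000', torch.randn(1, 1, 4), torch.randn(1, 1, 1)).split(2)
q1, q2 = fold.apply(net, [[x1], [x2]])  # two batched outputs, one per list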
def forward(self, state, mode="train"):
    self.clear_buffer()
    if mode == "inference":
        temp = self.batch_size
        self.batch_size = 1
    if not self.disable_fold:
        self.fold = torchfold.Fold()
        self.fold.cuda()
        self.zeroFold_td = self.fold.add("zero_func_td")
        self.zeroFold_bu = self.fold.add("zero_func_bu")
        self.a = []
    assert (
        state.shape[1] == self.state_dim * self.num_limbs
    ), "state.shape[1] expects {} but got {} with num_limbs being {} and state_dim being {}".format(
        self.state_dim * self.num_limbs,
        state.shape[1],
        self.num_limbs,
        self.state_dim,
    )

    for i in range(self.num_limbs):
        self.input_state[i] = state[:, i * self.state_dim:(i + 1) * self.state_dim]
        if not self.disable_fold:
            self.input_state[i] = torch.unsqueeze(self.input_state[i], 0)

    if self.bu:
        # bottom up transmission by recursion
        for i in range(self.num_limbs):
            self.bottom_up_transmission(i)

    if self.td:
        # top down transmission by recursion
        for i in range(self.num_limbs):
            self.top_down_transmission(i)

    if not self.bu and not self.td:
        for i in range(self.num_limbs):
            if not self.disable_fold:
                self.action[i] = self.fold.add("actor" + str(0).zfill(3),
                                               self.input_state[i])
            else:
                self.action[i] = self.actor[i](self.input_state[i])

    if not self.disable_fold:
        self.a += self.action
        self.action = self.fold.apply(self, [self.a])[0]
        self.action = torch.transpose(self.action, 0, 1)
        self.fold = None
    else:
        self.action = torch.stack(self.action, dim=-1)
        self.msg_down = torch.stack(self.msg_down, dim=-1)

    if mode == "inference":
        self.batch_size = temp

    return torch.squeeze(self.action)
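# The "zero_func_td" / "zero_func_bu" ops registered above are the base cases
# of the recursive message passing: leaves have no bottom-up message and the
# root has no top-down message. A hedged sketch of what such methods plausibly
# look like on this module; the (1, batch_size, msg_dim) shape is an assumption
# inferred from how fold inputs are unsqueezed above, not confirmed by the
# source.
def zero_func_td(self):
    # all-zero top-down message, batched as (1, batch_size, msg_dim) for fold
    return torch.zeros((1, self.batch_size, self.msg_dim), requires_grad=True)

def zero_func_bu(self):
    # all-zero bottom-up message with the same layout
    return torch.zeros((1, self.batch_size, self.msg_dim), requires_grad=True)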
def test_rnn(self):
    f = torchfold.Fold()
    v1, _ = f.add('value2', 1).split(2)
    v2, _ = f.add('value2', 2).split(2)
    r = v1
    for i in range(1000):
        r = f.add('attr', v1, v2)
        r = f.add('attr', r, v2)
    te = TestEncoder()
    enc = f.apply(te, [[r]])
    self.assertEqual(enc[0].size(), (1, 10))
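# These tests drive a TestEncoder whose method names match the op names passed
# to fold.add ('value', 'value2', 'attr', 'concat', 'logits'). A plausible
# sketch of such a module, consistent with the shapes the assertions expect;
# the layer sizes are assumptions and this is not necessarily the repo's own
# class. The 'concat'/'logits' ops are the ones exercised by test_nobatch
# below.
import torch
import torch.nn as nn

class TestEncoder(nn.Module):
    def __init__(self):
        super(TestEncoder, self).__init__()
        self.embed = nn.Embedding(10, 10)
        self.out = nn.Linear(20, 10)

    def value(self, idx):
        return self.embed(idx)                   # (batch, 10)

    def value2(self, idx):
        return self.embed(idx), self.embed(idx)  # two outputs, hence .split(2)

    def attr(self, left, right):
        return self.out(torch.cat([left, right], 1))

    def concat(self, *nodes):
        return torch.cat(nodes, 0)               # shared via .nobatch() below

    def logits(self, enc, embed):
        return torch.mm(enc, embed.t())          # (100, 15) in test_nobatch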
def Q1(self, state, action):
    self.clear_buffer()
    if not self.disable_fold:
        self.fold = torchfold.Fold()
        self.fold.cuda()
        self.zeroFold_td = self.fold.add("zero_func_td")
        self.zeroFold_bu = self.fold.add("zero_func_bu")
        self.x1_fold = []

    for i in range(self.num_limbs):
        self.input_state[i] = state[:, i * self.state_dim:(i + 1) * self.state_dim]
        self.input_action[i] = action[:, i]
        self.input_action[i] = torch.unsqueeze(self.input_action[i], -1)
        if not self.disable_fold:
            self.input_state[i] = torch.unsqueeze(self.input_state[i], 0)
            self.input_action[i] = torch.unsqueeze(self.input_action[i], 0)

    if self.bu:
        # bottom up transmission by recursion
        for i in range(self.num_limbs):
            self.bottom_up_transmission(i)

    if self.td:
        # top down transmission by recursion
        for i in range(self.num_limbs):
            self.top_down_transmission(i)

    if not self.bu and not self.td:
        for i in range(self.num_limbs):
            if not self.disable_fold:
                self.x1[i] = self.fold.add(
                    "critic" + str(0).zfill(3),
                    self.input_state[i],
                    self.input_action[i],
                )
            else:
                self.x1[i] = self.critic[i](self.input_state[i],
                                            self.input_action[i])

    if not self.disable_fold:
        if self.bu and not self.td:
            self.x1 = [self.x1]
        self.x1_fold = self.x1_fold + self.x1
        self.x1 = self.fold.apply(self, [self.x1_fold])[0]
        if not self.bu and not self.td:
            self.x1 = self.x1[0]
        self.x1 = torch.transpose(self.x1, 0, 1)
        self.fold = None
    else:
        self.x1 = torch.stack(self.x1, dim=-1)  # (bs, num_limbs, 1)

    return torch.sum(self.x1, dim=-1)
def test_nobatch(self):
    f = torchfold.Fold()
    v = []
    for i in range(15):
        v.append(f.add('value', i % 10))
    d = f.add('concat', *v).nobatch()
    res = []
    for i in range(100):
        res.append(f.add('logits', v[i % 10], d))
    te = TestEncoder()
    enc = f.apply(te, [res])
    self.assertEqual(len(enc), 1)
    self.assertEqual(enc[0].size(), (100, 15))
def main():
    inputs = datasets.snli.ParsedTextField(lower=True)
    transitions = datasets.snli.ShiftReduceField()
    answers = data.Field(sequential=False)

    train, dev, test = datasets.SNLI.splits(inputs, answers, transitions)
    inputs.build_vocab(train, dev, test)
    answers.build_vocab(train)
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train, dev, test), batch_size=args.batch_size,
        device=0 if args.cuda else -1)

    model = SPINN(3, 500, 1000)
    criterion = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=0.01)

    for epoch in range(10):
        start = time.time()
        iteration = 0
        for batch_idx, batch in enumerate(train_iter):
            opt.zero_grad()

            all_logits, all_labels = [], []
            fold = torchfold.Fold(cuda=args.cuda)
            for example in batch.dataset:
                tree = Tree(example, inputs.vocab, answers.vocab)
                if args.fold:
                    all_logits.append(encode_tree_fold(fold, tree))
                else:
                    all_logits.append(encode_tree_regular(model, tree))
                all_labels.append(tree.label)

            if args.fold:
                res = fold.apply(model, [all_logits, all_labels])
                loss = criterion(res[0], res[1])
            else:
                loss = criterion(torch.cat(all_logits, 0),
                                 Variable(torch.LongTensor(all_labels)))
            loss.backward()
            opt.step()

            iteration += 1
            if iteration % 10 == 0:
                print("Avg. Time: %fs" % ((time.time() - start) / iteration))
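# The fold path above differs from the regular path only in where the compute
# happens: encode_tree_fold records ops while walking the tree, and fold.apply
# replays them batched by depth. A hedged sketch of that pair, assuming
# leaf/children/logits methods on the SPINN model; the recursion shape is
# illustrative, not the exact example code.
def encode_tree_fold(fold, tree):
    def encode_node(node):
        if node.is_leaf():
            return fold.add('leaf', node.id)
        left = encode_node(node.left)
        right = encode_node(node.right)
        return fold.add('children', left, right)
    return fold.add('logits', encode_node(tree.root))

def encode_tree_regular(model, tree):
    def encode_node(node):
        if node.is_leaf():
            return model.leaf(Variable(torch.LongTensor([node.id])))
        left = encode_node(node.left)
        right = encode_node(node.right)
        return model.children(left, right)
    return model.logits(encode_node(tree.root))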
def test_rnn_optimized_chunking(self):
    seq_lengths = [2, 3, 5]
    states = []
    for _ in range(len(seq_lengths)):
        states.append(self._generate_variable(self.num_units))
    f = torchfold.Fold()
    for seq_ind in range(len(seq_lengths)):
        for _ in range(seq_lengths[seq_ind]):
            states[seq_ind] = f.add(
                'encode', self._generate_variable(self.input_size),
                states[seq_ind])
    enc = RNNEncoder(self.num_units, self.input_size)
    with mock.patch.object(torch, 'chunk', wraps=torch.chunk) as wrapped_chunk:
        result = f.apply(enc, [states])
    # torch.chunk is called 3 times instead of max(seq_lengths)=5.
    self.assertEqual(3, wrapped_chunk.call_count)
    self.assertEqual(len(result), 1)
    self.assertEqual(result[0].size(), (len(seq_lengths), self.num_units))
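# The 'encode' op above maps to an RNNEncoder method of the same name. A
# minimal sketch consistent with the sizes the test checks; the GRUCell choice
# is an assumption, and any step function with this (input, state) -> state
# signature would batch the same way under the fold.
import torch.nn as nn

class RNNEncoder(nn.Module):
    def __init__(self, num_units, input_size):
        super(RNNEncoder, self).__init__()
        self.cell = nn.GRUCell(input_size, num_units)

    def encode(self, input_, state):
        # One RNN step; torchfold batches all steps at the same depth together.
        return self.cell(input_, state)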
def compute_loss_annotated(self, batch):
    batch_size = batch.batch_size
    if self.args.cuda:
        batch = batch.cuda_train()
    initial_state, memory, initial_logits = self.model.prepare_initial(
        batch.input_grids, batch.output_grids, batch.current_grids,
        batch.current_code)
    max_code_length = memory.current_code.memory.shape[2]
    initial_state_orig = initial_state

    # state before: (batch x num pairs x hidden,
    #                num layers x batch x num pairs x hidden,
    #                num layers x batch x num pairs x hidden)
    # state after: list of (1 x num pairs x hidden,
    #                       1 x num layers x num pairs x hidden,
    #                       1 x num layers x num pairs x hidden)
    # Materialized as a list so it can be indexed per example under Python 3,
    # where zip returns an iterator.
    initial_state = list(zip(
        torch.chunk(initial_state.context, batch_size),
        torch.chunk(initial_state.h.permute(1, 0, 2, 3), batch_size),
        torch.chunk(initial_state.c.permute(1, 0, 2, 3), batch_size)))

    # memory: io, current_grid, current_code.memory, current_code.attn_mask
    # before: (batch x num pairs x 512,
    #          batch x num pairs x 256,
    #          batch x num pairs x max code length x 512,
    #          batch x num pairs x max code length)
    # after: list of (1 x num pairs x 512,
    #                 1 x num pairs x 256,
    #                 1 x num pairs x max code length x 512,
    #                 1 x num pairs x max code length)
    memory = list(zip(*(torch.chunk(t, batch_size) for t in memory.to_flat())))
    initial_logits = torch.chunk(initial_logits, batch_size)

    fold = torchfold.Fold(cuda=self.args.cuda)
    #fold = torchfold.Unfold(nn=self.model, cuda=self.args.cuda)
    zero = fold.add('tf_torch_zero')
    log_probs = []
    for batch_idx, allowed_edits in enumerate(batch.allowed_edits):
        item_log_probs = []
        item_memory = memory[batch_idx]
        # before: 1 x num pairs x length x hidden size
        # after: 1 x length x hidden size
        current_code_memory = item_memory[2][:, 0]
        current_code_attn_mask = item_memory[3][:, 0]

        def step(state, choice_name):
            output, context, h, c = fold.add(
                'tf_step',
                fold.add('choice_emb',
                         self.model.choice_vocab.stoi(choice_name)),
                *(state + item_memory)).split(4)
            return output, (context, h, c)

        def step_pointer(state, loc):
            output, context, h, c = fold.add(
                'tf_step',
                fold.add('tf_get_code_emb', current_code_memory, loc),
                *(state + item_memory)).split(4)
            return output, (context, h, c)

        def log_prob(logits, idx, size):
            assert idx < size
            return fold.add(
                'tf_get_log_prob:{}'.format(size),
                fold.add('tf_torch_log_softmax:{}'.format(size), logits),
                idx)

        def pointer_logits(output, loc):
            assert current_code_attn_mask[0, loc] == 0
            return fold.add('pointer_logits', output, current_code_memory,
                            current_code_attn_mask)

        def batched_sum(v1, v2, v3=zero, v4=zero, v5=zero):
            # v* shape: batch x 1
            return fold.add('tf_batched_sum', v1, v2, v3, v4, v5)

        for action_type, action_args in allowed_edits:
            if action_type == mutation.ADD_ACTION:
                location, karel_action = action_args
                action_log_prob = log_prob(
                    initial_logits[batch_idx],
                    self.model.initial_vocab.stoi(karel_action),
                    len(self.model.initial_vocab))
                output, state = step(initial_state[batch_idx], karel_action)
                loc_log_prob = log_prob(
                    pointer_logits(output, location), location,
                    max_code_length)
                item_log_probs.append(
                    batched_sum(action_log_prob, loc_log_prob))
            elif action_type == mutation.WRAP_BLOCK:
                block_type, cond_id, start, end = action_args
                block_type_log_prob = log_prob(
                    initial_logits[batch_idx],
                    self.model.initial_vocab.stoi(block_type),
                    len(self.model.initial_vocab))
                output, state = step(initial_state[batch_idx], block_type)
                if block_type == 'repeat':
                    cond_log_prob = log_prob(
                        fold.add('repeat_logits', output), cond_id,
                        len(mutation.REPEAT_COUNTS))
                    cond = len(mutation.CONDS) + cond_id
                else:
                    cond_log_prob = log_prob(
                        fold.add('cond_logits', output), cond_id,
                        len(mutation.CONDS))
                    cond = cond_id
                output, state = step(state, cond)
                start_log_prob = log_prob(
                    pointer_logits(output, start), start, max_code_length)
                output, state = step_pointer(state, start)
                end_log_prob = log_prob(
                    pointer_logits(output, end), end, max_code_length)
                item_log_probs.append(
                    batched_sum(block_type_log_prob, cond_log_prob,
                                start_log_prob, end_log_prob))
            elif action_type == mutation.WRAP_IFELSE:
                cond_id, if_start, else_start, end = action_args
                block_type_log_prob = log_prob(
                    initial_logits[batch_idx],
                    self.model.initial_vocab.stoi('ifElse'),
                    len(self.model.initial_vocab))
                output, state = step(initial_state[batch_idx], 'ifElse')
                cond_log_prob = log_prob(
                    fold.add('cond_logits', output), cond_id,
                    len(mutation.CONDS))
                output, state = step(state, cond_id)
                if_start_log_prob = log_prob(
                    pointer_logits(output, if_start), if_start,
                    max_code_length)
                output, state = step_pointer(state, if_start)
                else_start_log_prob = log_prob(
                    pointer_logits(output, else_start), else_start,
                    max_code_length)
                output, state = step_pointer(state, else_start)
                end_log_prob = log_prob(
                    pointer_logits(output, end), end, max_code_length)
                item_log_probs.append(
                    batched_sum(block_type_log_prob, cond_log_prob,
                                if_start_log_prob, else_start_log_prob,
                                end_log_prob))

        if not allowed_edits:
            item_log_probs.append(
                log_prob(initial_logits[batch_idx],
                         len(self.model.initial_vocab) - 1,
                         len(self.model.initial_vocab)))
        log_probs.append(item_log_probs)

    # log_probs before: list (batch size) of list (allowed_edits)
    # log_probs after: list (batch size) of Tensor, each with length
    # `allowed_edits`
    log_probs_t = fold.apply(self.model, log_probs)
    log_probs_per_example = [utils.logsumexp(t) for t in log_probs_t]
    loss = -torch.mean(torch.cat(log_probs_per_example))
    return loss, log_probs_per_example
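# utils.logsumexp above reduces each example's per-edit log-probs into a single
# marginal log-prob. A numerically stable sketch of that helper in case the
# utils module is not at hand; this is the standard max-shift trick, not
# necessarily the repo's exact implementation.
import torch

def logsumexp(t, dim=0):
    # log(sum(exp(t))) without overflow: subtract the max before exponentiating.
    m = t.max(dim=dim, keepdim=True)[0]
    # keepdim=True so a 1-D input yields a 1-element tensor, cat-able above.
    return m + (t - m).exp().sum(dim=dim, keepdim=True).log()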
def main():
    device_type = 'cuda' if args.cuda else 'cpu'
    device = torch.device(device_type)
    print("Running on: {}".format(device))

    #####################################
    ## configure experiment parameters ##
    #####################################
    batch_sizes = [1, 32, 64, 128, 256, 512, 1024]
    epochs = 1
    learning_rate = 0.001
    max_samples = 5000  # number of samples to use for experiment
    #####################################

    inputs = ParsedTextField(lower=True)
    transitions = ShiftReduceField()
    labels = data.Field(sequential=False)

    print("Loading dataset...")
    train, dev, test = datasets.SNLI.splits(inputs, labels, transitions)
    inputs.build_vocab(train, dev, test)
    labels.build_vocab(train)
    print("Done.")

    for batch_size in batch_sizes:
        print("Batching dataset into mini-batches of size {}..".format(batch_size))
        train_iter, _, _ = data.BucketIterator.splits((train, dev, test),
                                                      batch_size=batch_size,
                                                      device=device)
        print("Done.")

        print("Configuring SPINN model...")
        model = SPINN(3, 500, len(inputs.vocab))
        if args.cuda:
            model.to(device)
        criterion = nn.CrossEntropyLoss()
        opt = optim.Adam(model.parameters(), lr=learning_rate)
        print("Done.")

        for epoch in range(epochs):
            print("starting epoch {}".format(epoch))
            all_batch_times = []
            for batch_idx, batch in enumerate(train_iter):
                opt.zero_grad()  # reset gradients per mini-batch
                all_logits, all_labels = [], []
                if args.dynamic:
                    fold = torchfold.Fold()
                    if args.cuda:
                        fold.cuda()
                start = timer()
                tree_sizes = []
                # because batch.dataset starts at the beginning of the entire
                # dataset instead of where the previous batch left off
                for sample_idx in range(batch_idx * batch_size,
                                        (batch_idx + 1) * batch_size):
                    # HACK this is to account for the final batch which may or
                    # may not be of size batch_size - there should be a more
                    # elegant solution to this
                    if sample_idx == len(batch.dataset) - 1:
                        break
                    tree = Tree(batch.dataset[sample_idx].label,
                                batch.dataset[sample_idx].premise_transitions,
                                batch.dataset[sample_idx].premise,
                                inputs.vocab, labels.vocab)
                    if args.dynamic:
                        all_logits.append(encode_tree_fold(fold, tree))
                    else:
                        all_logits.append(encode_tree_regular(model, tree))
                    all_labels.append(tree.label)

                if args.dynamic:
                    res = fold.apply(model, [all_logits, all_labels])
                    batch_time = timer() - start
                    loss = criterion(res[0], res[1])
                else:
                    # renamed from `test`, which shadowed the SNLI test split
                    labels_np = np.asarray(all_labels, dtype=int)
                    x = torch.from_numpy(labels_np).to(device)
                    batch_time = timer() - start
                    loss = criterion(torch.cat(all_logits, 0), x)
                loss.backward()
                opt.step()

                ####################
                ## Gather results ##
                ####################
                all_batch_times.append(batch_time)
                results['time'].append(batch_time)
                results['epoch'].append(epoch)
                results['batch'].append(batch_idx)
                results['sample'].append(sample_idx)
                results['batch_size'].append(batch_size)
                ts = tree_size(tree.root)
                tree_sizes.append(ts)
                results['treesize'].append(np.mean(tree_sizes))
                ####################

                if batch_idx % 10 == 1:
                    print("batch size: {} sample: {}/{} loss:{:4f} - Avg. Time (per batch): {:5f}s"
                          .format(batch_size, batch_idx * batch_size,
                                  max_samples, loss, np.mean(all_batch_times)))
                # only need to look at first 5000 samples for each batch
                if batch_idx * batch_size > max_samples:
                    break
            print("done epoch {}".format(epoch))

        with open(os.path.join(ROOT,
                               "results_fold-{}-{}-{}-backup.json".format(
                                   args.dynamic, batch_size, args.cuda)),
                  "w+") as fd:
            json.dump(results, fd)

    with open(os.path.join(ROOT,
                           "results_fold-{}-{}-full.json".format(
                               args.dynamic, args.cuda)),
              "w+") as fd:
        json.dump(results, fd)
grassData = GRASS('data')
dataloader = torch.utils.data.DataLoader(grassData, batch_size=123,
                                         shuffle=True,
                                         collate_fn=class_collate)
optimizer_encoder = torch.optim.Adam(encoder.parameters(), lr=1e-3)
optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=1e-3)

for epoch in range(500):
    if epoch % 100 == 0 and epoch != 0:
        torch.save(encoder, 'VAEencoder.pkl')
        torch.save(decoder, 'VAEdecoder.pkl')
    for i, batch in enumerate(dataloader):
        # encode every tree in the batch through one fold
        fold = torchfold.Fold(cuda=True, variable=False)
        encoding = []
        for example in batch:
            encoding.append(model.encode_structure_fold(fold, example))
        encoding = fold.apply(encoder, [encoding])
        encoding = torch.split(encoding[0], 1, 0)

        # decode each latent code through a second fold
        decodingLoss = []
        fold = torchfold.Fold(cuda=True, variable=True)
        kldLoss = []
        for example, f in zip(batch, encoding):
            ff, kld = torch.chunk(f, 2, 1)
            decodingLoss.append(model.decode_structure_fold(fold, ff, example))
            kldLoss.append(kld)
        decodingLoss = fold.apply(decoder, [decodingLoss, kldLoss])
        err_re = decodingLoss[0].sum() / len(batch)
        err_kld = decodingLoss[1].sum().mul(-0.05) / len(batch)
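# encode_structure_fold above records the recursive GRASS encoder as fold ops
# instead of running it eagerly. A hedged sketch of that traversal, assuming
# leaf and adjacency node types with op names 'boxEncoder'/'adjEncoder' and a
# final 'sampleEncoder' projection; the real GRASS node taxonomy also includes
# symmetry nodes, omitted here for brevity.
def encode_structure_fold(fold, tree):
    def encode_node(node):
        if node.is_leaf():
            # leaf: embed the raw box parameters
            return fold.add('boxEncoder', node.box)
        left = encode_node(node.left)
        right = encode_node(node.right)
        # internal node: merge the two child codes
        return fold.add('adjEncoder', left, right)
    # project the root code to the VAE's (mean, logvar) pair
    return fold.add('sampleEncoder', encode_node(tree.root))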
def _training_pass(self, valid_rooms, epoch, is_training=True):
    """
    Single training pass
    :param valid_rooms: choice of [self.valid_rooms_train, self.valid_rooms_test]
    :param epoch: current epoch
    :param is_training: train or test pass
    :return:
    """
    ''' epoch and args '''
    epoch += self.pretrained_epoch
    opt_parser = self.opt_parser

    ''' current training state '''
    if is_training:
        self.STATE = 'TRAIN'
        self.full_enc.train()
        self.full_dec.train()
    else:
        self.STATE = 'EVAL'
        self.full_enc.eval()
        self.full_dec.eval()

    ''' init loss / accuracy '''
    loss_cat_per_epoch, acc_cat_per_epoch, loss_dim_per_epoch, num_node_per_epoch, dim_acc_per_epoch = \
        0.0, {1: 0.0, 3: 0.0, 5: 0.0}, 0.0, 0.0, 0.0

    ''' shuffle room list and create training batches '''
    shuffle(valid_rooms)
    room_indices = list(range(len(valid_rooms)))
    room_idx_batches = [room_indices[i: i + opt_parser.batch_size]
                        for i in range(0, len(valid_rooms), opt_parser.batch_size)]

    ''' Batch loop '''
    for batch_i, batch in enumerate(room_idx_batches):
        batch_rooms = [valid_rooms[i] for i in batch]

        """
        ==================================================================
        Encoder Part
        ==================================================================
        """
        # init torchfold
        enc_fold = torchfold.Fold()
        enc_fold_nodes = []
        enc_rand_path_order = []
        enc_rand_path_root_to_leaf_order = []

        # loop for rooms
        for room_i, room in enumerate(batch_rooms):
            node_list = self.__preprocess_root_wall_nodes__(room['node_list'])

            # adapt acceleration for large graphs (by splitting into sub-graphs)
            consider_path_type = ['root']
            root_to_split = False
            if opt_parser.adapt_training_on_large_graph:
                if len(node_list.keys()) >= int(opt_parser.max_scene_nodes):
                    consider_path_type = node_list['root']['support']
                    root_to_split = True

            # loop for sub-graphs
            for sub_tree_root_node in consider_path_type:
                # find sub-graph's root to leaf node path
                subtree_to_leaf_path = self.find_root_to_leaf_node_path(
                    node_list, cur_node=sub_tree_root_node)

                # skip unreasonable paths
                subtree_to_leaf_path = [p for p in subtree_to_leaf_path
                                        if len(p) >= 2 and len(p) < opt_parser.max_scene_nodes]
                subtree_to_leaf_path = [p for p in subtree_to_leaf_path
                                        if 'wall' not in p[-1].split('_')[0]]
                if len(subtree_to_leaf_path) == 0:
                    continue

                # find node list for sub-graphs
                sub_keys = list(set(self.find_selected_node_list(node_list, sub_tree_root_node)))
                if root_to_split:
                    sub_keys += ['root']
                sub_node_list = dict((k, node_list[k]) for k in sub_keys
                                     if k in node_list.keys())

                # update parents, childs, siblings for each node
                sub_node_list = self.find_parent_sibling_child_list(sub_node_list)

                # exclude examples with too many sub tree nodes
                if len(sub_node_list.keys()) >= int(opt_parser.max_scene_nodes):
                    print('skip too large sub-scene:', len(sub_node_list.keys()),
                          '>', opt_parser.max_scene_nodes)
                    continue

                subtree_to_leaf_path.sort()

                # loop for each root-to-leaf path
                for rand_path_idx, rand_path in enumerate(subtree_to_leaf_path):
                    rand_path_fold, rand_path_node_name_order = self.model.encode_tree_fold(
                        enc_fold, sub_node_list, rand_path, opt_parser)
                    enc_fold_nodes += rand_path_fold
                    enc_rand_path_order += [[room_i, sub_tree_root_node] + rand_path_node_name_order]
                    enc_rand_path_root_to_leaf_order += [rand_path]

        # if batch size is too small, sometimes there is no valid training instance.
        if len(enc_fold_nodes) == 0:
            print('surprisingly this batch has no valid training trees!')
            continue

        # torch-fold train encoder
        enc_fold_nodes = enc_fold.apply(self.full_enc, [enc_fold_nodes])
        enc_fold_nodes = torch.split(enc_fold_nodes[0], 1, 0)

        """
        ==================================================================
        Decoder Part
        ==================================================================
        """
        # Ground-truth leaf node vec
        leaf_node_gt = []

        # FOLD
        dec_fold = torchfold.Fold()
        dec_fold_nodes = []

        # loop for all encoded vectors
        for i, rand_path_order in enumerate(enc_rand_path_order):
            # find room-node-list
            room_i = rand_path_order[0]
            node_list = batch_rooms[room_i]['node_list']

            # decode to k-vec and add Ops to fold
            dec_fold_nodes.append(self.model.decode_tree_fold(
                dec_fold, enc_fold_nodes[i], opt_parser))
            # leaf node ground-truth k-vec
            leaf_node_gt += [self.model.get_gt_k_vec(
                node_list, enc_rand_path_root_to_leaf_order[i][-1], opt_parser)]

        # torch-fold decoder
        dec_fold_nodes = dec_fold.apply(self.full_dec, [dec_fold_nodes])
        leaf_node_pred = dec_fold_nodes[0]

        """
        ==================================================================
        Loss / Accuracy Part
        ==================================================================
        """
        size_pos_dim = 6
        leaf_node_cat_gt = [c[:-size_pos_dim].index(1) for c in leaf_node_gt]
        leaf_node_cat_gt = to_torch(leaf_node_cat_gt, torch_type=torch.LongTensor,
                                    dim_0=len(leaf_node_gt)).view(-1)
        leaf_node_dim_gt = [c[-size_pos_dim:-size_pos_dim + 3] for c in leaf_node_gt]
        leaf_node_dim_gt = to_torch(leaf_node_dim_gt, torch_type=torch.FloatTensor,
                                    dim_0=len(leaf_node_gt))

        loss_cat = self.LOSS_CLS(leaf_node_pred[:, :-size_pos_dim], leaf_node_cat_gt)
        loss_dim = self.LOSS_L2(leaf_node_pred[:, -size_pos_dim:-size_pos_dim + 3],
                                leaf_node_dim_gt) * 1000

        # report scores
        loss_cat_per_batch = loss_cat.data.cpu().numpy()
        loss_dim_per_batch = loss_dim.data.cpu().numpy()
        num_node_per_batch = len(leaf_node_gt) * 1.0

        # accuracy (top k)
        acc_cat_per_batch = {}
        for k in [1, 3, 5]:
            _, pred = leaf_node_pred[:, :-size_pos_dim].topk(k, 1, True, True)
            pred = pred.t()
            correct = pred.eq(leaf_node_cat_gt.view(1, -1).expand_as(pred))
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            acc_cat_per_batch[k] = correct_k[0].cpu().numpy()

        # dimension (diagonal) percentage off
        diag_pred = np.sqrt(np.sum(
            leaf_node_pred[:, -size_pos_dim:-size_pos_dim + 3].data.cpu().numpy() ** 2,
            axis=1))
        diag_gt = np.sqrt(np.sum(leaf_node_dim_gt.data.cpu().numpy() ** 2, axis=1))
        dim_acc_per_batch = np.sum(np.abs(diag_pred - diag_gt) / diag_gt)

        loss_cat_per_epoch += loss_cat_per_batch
        loss_dim_per_epoch += loss_dim_per_batch
        num_node_per_epoch += num_node_per_batch
        dim_acc_per_epoch += dim_acc_per_batch
        for key in acc_cat_per_epoch.keys():
            acc_cat_per_epoch[key] += acc_cat_per_batch[key]

        if is_training:
            # Back-propagation
            for key in self.opt.keys():
                self.opt[key].zero_grad()

            # only train object dimensions
            if opt_parser.train_dim and not opt_parser.train_cat:
                loss_dim.backward()
            # only train object categories
            elif opt_parser.train_cat and not opt_parser.train_dim:
                loss_cat.backward()
            # train both
            elif opt_parser.train_cat and opt_parser.train_dim:
                loss_cat.backward(retain_graph=True)
                loss_dim.backward()
            else:
                print('At least enable --train_cat or --train_dim.')
                exit(-1)

            for key in self.opt.keys():
                self.opt[key].step()

        if opt_parser.verbose >= 0:
            print(self.STATE, opt_parser.name, epoch,
                  ': ({}/{}:{})'.format(batch_i, len(room_idx_batches), num_node_per_batch),
                  'CAT Loss: {:.4f}, Acc_1: {:.4f}, Acc_3: {:.4f}, Acc_5: {:.4f}, Dim Loss: {:.8f}, dim acc: {:.2f}'.format(
                      loss_cat_per_batch / num_node_per_batch * 100.0,
                      acc_cat_per_batch[1] / num_node_per_batch,
                      acc_cat_per_batch[3] / num_node_per_batch,
                      acc_cat_per_batch[5] / num_node_per_batch,
                      loss_dim_per_batch / num_node_per_batch,
                      dim_acc_per_batch / num_node_per_batch))

    """
    ==================================================================
    Report Part
    ==================================================================
    """
    print('========================================================')
    print(self.STATE, epoch, ': ',
          'CAT Loss: {:.4f}, Acc_1: {:.4f}, Acc_3: {:.4f}, Acc_5: {:.4f}, Dim Loss: {:.4f}, Dim acc: {:.4f}'.format(
              loss_cat_per_epoch / num_node_per_epoch * 100.0,
              acc_cat_per_epoch[1] / num_node_per_epoch,
              acc_cat_per_epoch[3] / num_node_per_epoch,
              acc_cat_per_epoch[5] / num_node_per_epoch,
              loss_dim_per_epoch / num_node_per_epoch,
              dim_acc_per_epoch / num_node_per_epoch))
    print('========================================================')

    ''' write avg to log '''
    if opt_parser.write:
        self.writer.add_scalar('{}_LOSS_CAT'.format(self.STATE),
                               loss_cat_per_epoch / num_node_per_epoch, epoch)
        self.writer.add_scalar('{}_ACC_CAT'.format(self.STATE),
                               acc_cat_per_epoch[1] / num_node_per_epoch, epoch)
        self.writer.add_scalar('{}_ACC_3_CAT'.format(self.STATE),
                               acc_cat_per_epoch[3] / num_node_per_epoch, epoch)
        self.writer.add_scalar('{}_ACC_5_CAT'.format(self.STATE),
                               acc_cat_per_epoch[5] / num_node_per_epoch, epoch)
        self.writer.add_scalar('{}_LOSS_DIM'.format(self.STATE),
                               loss_dim_per_epoch / num_node_per_epoch, epoch)

    ''' save model '''
    if not is_training:
        def save_model(save_type):
            torch.save({
                'full_enc_state_dict': self.full_enc.state_dict(),
                'full_dec_state_dict': self.full_dec.state_dict(),
                'full_enc_opt': self.opt['full_enc'].state_dict(),
                'full_dec_opt': self.opt['full_dec'].state_dict(),
                'epoch': epoch
            }, '{}/Entire_model_{}.pth'.format(opt_parser.outf, save_type))

        # if model is better, save model checkpoint
        # min dim loss model
        if loss_dim_per_epoch / num_node_per_epoch < self.MIN_DIM_LOSS:
            self.MIN_DIM_LOSS = loss_dim_per_epoch / num_node_per_epoch
            save_model('min_dim_loss')
        # max cat acc model (top-5 acc)
        if acc_cat_per_epoch[5] / num_node_per_epoch > self.MAX_ACC:
            self.MAX_ACC = acc_cat_per_epoch[5] / num_node_per_epoch
            save_model('max_acc')
        # min cat loss model
        if loss_cat_per_epoch / num_node_per_epoch < self.MIN_LOSS:
            self.MIN_LOSS = loss_cat_per_epoch / num_node_per_epoch
            save_model('min_loss')
        # always save the latest model
        save_model('last_epoch')

    return