def __init__(self, exp, debug=False):
    """Generate the trace for ``exp`` as soon as the object is created.

    Args:
        exp: expression handed unchanged to ``ScratchPad``.
        debug: forwarded to ``ScratchPad`` as its ``debug`` keyword.
    """
    super().__init__()
    # Trace entries are accumulated here by build().
    self.trace = []
    self.exp = exp
    self.scratch = ScratchPad(exp, debug=debug)
    # build() needs both self.trace and self.scratch, so it runs last.
    self.build()
def run_epoch(model, mode, cur_data, writer, path):
    """Run one pass over ``cur_data`` and return the per-step averaged metrics.

    Relies on module-level names defined elsewhere in this file: ``device``,
    ``optimizer``, ``epoch`` (only used in the progress messages),
    ``NPI_LOSS``, ``ScratchPad`` and the three logging counters
    ``train_n_iter`` / ``same_n_iter`` / ``diff_n_iter``.

    Args:
        model: network to run. It is called as
            ``model(env_ft, arg_in_ft, pro_in_ft, h0, initial)`` and must
            return a pair whose first item is the prediction.
        mode: ``'train'`` shuffles ``cur_data`` in place, puts the model in
            training mode and takes an optimizer step after every trace
            step. ``'test'`` first loads the weights stored at ``path``.
            Any other value only evaluates. Scalars are written to
            ``writer`` only for ``'train'``, ``'eval'`` and ``'test'``.
        cur_data: sequence of ``(exp, trace)`` pairs; each trace entry is
            ``((program_name, program_id), args, terminate)``.
        writer: object with an ``add_scalar(tag, value, step)`` method, or
            ``None`` to disable scalar logging.
        path: checkpoint path, read only when ``mode == 'test'``.

    Returns:
        Tuple ``(default_loss, total_loss, termination_acc, program_acc)``,
        each averaged over every trace step processed. All four are ``0.0``
        when no step was processed.
    """
    global train_n_iter
    global same_n_iter
    global diff_n_iter

    training = mode == 'train'
    if training:
        random.shuffle(cur_data)
        model.train()
    elif mode == 'test':
        # map_location lets a checkpoint saved on another device be loaded.
        model.load_state_dict(torch.load(path, map_location=device))
        model.eval()
    else:
        model.eval()

    epoch_def_loss = 0.0
    epoch_total_loss = 0.0
    epoch_pro_accs = 0.0
    epoch_ter_accs = 0.0
    epoch_arg_accs = [0.0, 0.0]
    epoch_step = 0

    start_time = time.time()
    criterion = NPI_LOSS()

    for idx, (exp, trace) in enumerate(cur_data):
        # Step t of the trace is the input, step t + 1 is the target.
        x, y = trace[:-1], trace[1:]
        n_steps = len(x)
        if n_steps == 0:
            # A trace with fewer than two entries has no transition to learn
            # from and would make every per-sample average divide by zero.
            continue

        Pad = ScratchPad(exp)
        # NOTE(review): h0 is passed unchanged on every step and the second
        # value returned by the model is discarded; presumably the model
        # keeps its recurrent state internally and resets it when `initial`
        # is True - confirm against the model definition.
        h0 = torch.zeros((2, 1, 256)).to(device)

        step_def_loss = 0.0
        step_total_loss = 0.0
        pro_accs = 0.0
        ter_accs = 0.0
        arg_accs = [0.0, 0.0]

        # Autograd bookkeeping is only needed when we are going to backprop.
        with torch.set_grad_enabled(training):
            for trace_idx, (step_in, step_out) in enumerate(zip(x, y)):
                (_, pro_in_id), arg_in, _ = step_in
                (_, pro_out_id), arg_out, ter_out = step_out

                # Apply the input program first so the environment features
                # describe the scratch pad after that call.
                Pad.execute(pro_in_id, arg_in)
                env_ft = torch.from_numpy(Pad.get_env()).view(1, -1)

                arg_in_ft = torch.from_numpy(Pad.encode_args(arg_in)).view(1, -1)
                arg_out_ft = torch.from_numpy(Pad.encode_args(arg_out)).view(1, -1)

                pro_in_ft = torch.from_numpy(np.array([pro_in_id])).view(1, -1)
                pro_out_ft = torch.from_numpy(np.array([pro_out_id])).view(-1)

                ter_out_ft = torch.from_numpy(np.array([1] if ter_out else [0])).view(-1)

                arg_in_ft = arg_in_ft.to(device)
                arg_out_ft = arg_out_ft.to(device)
                pro_in_ft = pro_in_ft.to(device)
                pro_out_ft = pro_out_ft.to(device)
                ter_out_ft = ter_out_ft.to(device)
                env_ft = env_ft.to(device)

                initial = (trace_idx == 0)
                # Use the model that was passed in (and configured above),
                # not a module-level network.
                pred, _ = model(env_ft, arg_in_ft, pro_in_ft, h0, initial)

                gt = (ter_out_ft, pro_out_ft, arg_out_ft)
                default_loss, total_loss = criterion(pred, gt)
                pro_acc, ter_acc, arg_acc = criterion.metric(pred, gt)

                pro_accs += pro_acc
                ter_accs += ter_acc
                arg_accs[0] += arg_acc[0]
                arg_accs[1] += arg_acc[1]

                if training:
                    optimizer.zero_grad()
                    total_loss.backward(retain_graph=True)
                    optimizer.step()

                step_def_loss += default_loss.item()
                step_total_loss += total_loss.item()

        avg_def_loss = step_def_loss / n_steps
        avg_total_loss = step_total_loss / n_steps
        avg_pro_acc = pro_accs / n_steps
        avg_ter_acc = ter_accs / n_steps

        if idx % 10 == 0:
            print("Epoch {0:02d} idx {1:03d} Default Step Loss {2:05f}, "
                  "Total Step Loss {3:05f}, Term Acc: {4:03f}, Prog Acc: {5:03f}"
                  .format(epoch, idx, avg_def_loss, avg_total_loss,
                          avg_ter_acc, avg_pro_acc))

        if writer is not None:
            # Each logged mode has its own global step counter.
            if mode == 'train':
                n_iter = train_n_iter
                train_n_iter += 1
            elif mode == 'eval':
                n_iter = same_n_iter
                same_n_iter += 1
            elif mode == 'test':
                n_iter = diff_n_iter
                diff_n_iter += 1
            else:
                n_iter = None

            if n_iter is not None:
                writer.add_scalar(mode + '/def_loss', avg_def_loss, n_iter)
                writer.add_scalar(mode + '/total_loss', avg_total_loss, n_iter)
                writer.add_scalar(mode + '/pro_accs', avg_pro_acc, n_iter)
                writer.add_scalar(mode + '/ter_accs', avg_ter_acc, n_iter)

        epoch_def_loss += step_def_loss
        epoch_total_loss += step_total_loss
        epoch_pro_accs += pro_accs
        epoch_ter_accs += ter_accs
        epoch_arg_accs[0] += arg_accs[0]
        epoch_arg_accs[1] += arg_accs[1]
        epoch_step += n_steps

    end_time = time.time()
    epoch_time = end_time - start_time

    # With no processed step every sum is still 0, so dividing by 1 reports
    # zeros instead of raising ZeroDivisionError.
    denom = epoch_step if epoch_step > 0 else 1
    mean_def_loss = epoch_def_loss / denom
    mean_total_loss = epoch_total_loss / denom
    mean_ter_acc = epoch_ter_accs / denom
    mean_pro_acc = epoch_pro_accs / denom

    print("Mode: {0:s} For whole Epoch {1:02d}, Time Consum {2:05f} Default Step Loss {3:05f}, "
          "Total Step Loss {4:05f}, Term Acc: {5:03f}, Prog Acc: {6:03f}, Arg0 Acc: {7:03f}, Arg1 Acc: {8:03f}"
          .format(mode, epoch, epoch_time, mean_def_loss, mean_total_loss,
                  mean_ter_acc, mean_pro_acc,
                  epoch_arg_accs[0] / denom, epoch_arg_accs[1] / denom))
    print('===============================')

    return (mean_def_loss, mean_total_loss, mean_ter_acc, mean_pro_acc)
class Trace():
    """Record the program calls made while a ``ScratchPad`` consumes ``exp``.

    Every entry of ``self.trace`` has the form
    ``((program, P[program]), args, terminate)``. The whole trace is
    generated in the constructor by ``build()``.
    """

    def __init__(self, exp, debug=False):
        """Create the scratch pad for ``exp`` and generate the trace."""
        super().__init__()
        # Trace entries are accumulated here by build().
        self.trace = []
        self.exp = exp
        self.scratch = ScratchPad(exp, debug=debug)
        self.build()

    def _record(self, program, args, terminate=False):
        """Append one ``((program, P[program]), args, terminate)`` entry."""
        self.trace.append(((program, P[program]), args, terminate))

    def _discard_top(self, s_top):
        """Emit the steps that remove ``s_top`` from the stack, then pop it.

        Anything that is not a parenthesis is written to the output first.
        """
        if I2A[s_top] not in ('(', ')'):
            self.write_out(s_top)
        self._record(WRITE, [STACK_PTR, 0])
        self._record(MOV_PTR, [STACK_PTR, LEFT])
        self.scratch.pop()

    def build(self):
        """Emit the opening REVPOLI entry, then step until the pad is done."""
        self._record(REVPOLI, [])
        while not self.scratch.done():
            self.precedence()
            self.next()

    def precedence(self):
        """Emit a PRECE entry and perform the action chosen by ``scratch.prece()``."""
        self._record(PRECE, [])
        action, value = self.scratch.prece()
        if action == 3:
            self.pop_loop()
        elif action == 2:
            self.write_out(value)
        elif action in (0, 1):
            # Action 1 replaces the stack top: pop first, then push.
            if action == 1:
                self.pop()
            self.push(value)

    def write_out(self, value):
        """Write ``value`` to the output and advance the output pointer."""
        self._record(WRITE, [WRITE_OUT, value])
        self._record(MOV_PTR, [OUT_PTR, RIGHT])
        self.scratch.write_out(value)

    def push(self, value):
        """Advance the stack pointer and write ``value`` at the new top."""
        self._record(MOV_PTR, [STACK_PTR, RIGHT])
        self._record(WRITE, [STACK_PTR, value])
        self.scratch.push(value)

    def pop(self):
        """Remove the current stack top once."""
        self._discard_top(self.scratch.pop_read())

    def pop_loop(self):
        """Keep removing stack tops until a '(' has been removed."""
        s_top = self.scratch.pop_read()
        self._discard_top(s_top)
        if I2A[s_top] != '(':
            self.pop_loop()

    def pop_all(self):
        """Remove stack tops until ``pop_read()`` returns 0, then emit the
        final entry, which is the only one flagged as terminating."""
        s_top = self.scratch.pop_read()
        if s_top == 0:
            self._record(MOV_PTR, [STACK_PTR, LEFT], True)
            return
        self._discard_top(s_top)
        self.pop_all()

    def next(self):
        """Advance the expression pointer; flush the stack once the pad is done."""
        self._record(MOV_PTR, [EXP_PTR, RIGHT])
        self.scratch.next()
        if self.scratch.done():
            self.pop_all()