Esempio n. 1
0
 def __init__(self, exp, debug=False):
     """Generate the trace for ``exp`` so it can later be printed or consumed.

     Stores the expression, starts with an empty trace list, creates the
     ``ScratchPad`` working state for the expression and finally calls
     ``self.build()``.

     Args:
         exp: Expression handed unchanged to ``ScratchPad``.
         debug: Forwarded to ``ScratchPad`` as its ``debug`` keyword.
     """
     super().__init__()
     self.exp = exp
     self.trace = []
     self.scratch = ScratchPad(exp, debug=debug)
     self.build()
def _encode_step(pad, step_in, step_out):
    """Execute one trace step on ``pad`` and build model inputs and targets.

    ``step_in`` / ``step_out`` are trace entries of the form
    ``((name, program_id), args, terminate)``. The input step is executed on
    the scratch pad first, so the environment features reflect its effect.

    Returns:
        ``((env_ft, arg_in_ft, pro_in_ft), (ter_out_ft, pro_out_ft, arg_out_ft))``
        with every tensor already moved to the module-level ``device``.
    """
    (_, pro_in_id), arg_in, _ = step_in
    (_, pro_out_id), arg_out, ter_out = step_out

    pad.execute(pro_in_id, arg_in)
    env_ft = torch.from_numpy(pad.get_env()).view(1, -1)

    arg_in_ft = torch.from_numpy(pad.encode_args(arg_in)).view(1, -1)
    arg_out_ft = torch.from_numpy(pad.encode_args(arg_out)).view(1, -1)

    pro_in_ft = torch.from_numpy(np.array([pro_in_id])).view(1, -1)
    pro_out_ft = torch.from_numpy(np.array([pro_out_id])).view(-1)
    ter_out_ft = torch.from_numpy(np.array([1 if ter_out else 0])).view(-1)

    inputs = (env_ft.to(device), arg_in_ft.to(device), pro_in_ft.to(device))
    targets = (ter_out_ft.to(device), pro_out_ft.to(device),
               arg_out_ft.to(device))
    return inputs, targets


def _log_step_scalars(writer, mode, def_loss, total_loss, pro_acc, ter_acc,
                      n_iter):
    """Write the four per-sample scalars under the ``mode`` tag prefix."""
    writer.add_scalar(mode + '/def_loss', def_loss, n_iter)
    writer.add_scalar(mode + '/total_loss', total_loss, n_iter)
    writer.add_scalar(mode + '/pro_accs', pro_acc, n_iter)
    writer.add_scalar(mode + '/ter_accs', ter_acc, n_iter)


def run_epoch(model, mode, cur_data, writer, path):
    """Run one pass over ``cur_data`` and return step-averaged metrics.

    Relies on module-level names defined outside this function: ``device``,
    ``npi``, ``optimizer``, ``epoch`` (only used in the log messages),
    ``NPI_LOSS``, ``ScratchPad`` and the three ``*_n_iter`` counters.

    Args:
        model: Network switched to train/eval mode here; in ``'test'`` mode its
            weights are first loaded from ``path``.
        mode: ``'train'`` shuffles ``cur_data`` in place and runs an optimizer
            step after every trace step; ``'test'`` loads weights from ``path``
            and evaluates; any other value (e.g. ``'eval'``) only evaluates.
        cur_data: Sequence of ``(exp, trace)`` pairs.
        writer: Object exposing ``add_scalar(tag, value, step)``, or ``None``
            to disable scalar logging.
        path: Checkpoint path, read only in ``'test'`` mode.

    Returns:
        ``(default_loss, total_loss, termination_acc, program_acc)``, each
        divided by the total number of trace steps processed.
    """
    global train_n_iter
    global same_n_iter
    global diff_n_iter

    training = mode == 'train'
    if training:
        random.shuffle(cur_data)
        model.train()
    elif mode == 'test':
        # map_location lets a checkpoint saved on another device be loaded.
        model.load_state_dict(torch.load(path, map_location=device))
        model.eval()
    else:
        model.eval()

    epoch_def_loss = 0.0
    epoch_total_loss = 0.0
    epoch_pro_accs = 0.0
    epoch_ter_accs = 0.0
    epoch_arg_accs = [0.0, 0.0]
    epoch_step = 0
    start_time = time.time()

    criterion = NPI_LOSS()

    for idx, (exp, trace) in enumerate(cur_data):
        pad = ScratchPad(exp)
        # Consecutive trace entries form (input step, target step) pairs.
        x, y = trace[:-1], trace[1:]
        # Guard the per-sample averages against one-entry traces.
        n_steps = max(len(x), 1)
        # Presumably the initial hidden state expected by ``npi`` -- the
        # (2, 1, 256) shape is hard-coded here; confirm against the model.
        h0 = torch.zeros((2, 1, 256)).to(device)

        step_def_loss = 0.0
        step_total_loss = 0.0
        pro_accs = 0.0
        ter_accs = 0.0
        arg_accs = [0.0, 0.0]

        for trace_idx, (step_in, step_out) in enumerate(zip(x, y)):
            (env_ft, arg_in_ft, pro_in_ft), gt = _encode_step(
                pad, step_in, step_out)

            initial = (trace_idx == 0)
            # Only build autograd graphs when we are going to backpropagate.
            with torch.set_grad_enabled(training):
                # NOTE(review): the forward pass uses the module-level ``npi``
                # rather than the ``model`` argument -- presumably the same
                # object; confirm with the caller.
                pred, _ = npi(env_ft, arg_in_ft, pro_in_ft, h0, initial)
                default_loss, total_loss = criterion(pred, gt)

            pro_acc, ter_acc, arg_acc = criterion.metric(pred, gt)
            pro_accs += pro_acc
            ter_accs += ter_acc
            arg_accs[0] += arg_acc[0]
            arg_accs[1] += arg_acc[1]

            if training:
                optimizer.zero_grad()
                total_loss.backward(retain_graph=True)
                optimizer.step()

            # Accumulate in every mode so eval/test report real losses
            # instead of a constant 0.0.
            step_def_loss += default_loss.item()
            step_total_loss += total_loss.item()

        avg_def_loss = step_def_loss / n_steps
        avg_total_loss = step_total_loss / n_steps
        avg_pro_acc = pro_accs / n_steps
        avg_ter_acc = ter_accs / n_steps

        if idx % 10 == 0:
            print(("Epoch {0:02d} idx {1:03d} Default Step Loss {2:05f}, "
                   "Total Step Loss {3:05f}, Term Acc: {4:03f}, Prog Acc: {5:03f}")
                  .format(epoch, idx, avg_def_loss, avg_total_loss,
                          avg_ter_acc, avg_pro_acc))

        if writer is not None:
            # Each mode advances its own global logging counter.
            if training:
                _log_step_scalars(writer, mode, avg_def_loss, avg_total_loss,
                                  avg_pro_acc, avg_ter_acc, train_n_iter)
                train_n_iter += 1
            elif mode == 'eval':
                _log_step_scalars(writer, mode, avg_def_loss, avg_total_loss,
                                  avg_pro_acc, avg_ter_acc, same_n_iter)
                same_n_iter += 1
            elif mode == 'test':
                _log_step_scalars(writer, mode, avg_def_loss, avg_total_loss,
                                  avg_pro_acc, avg_ter_acc, diff_n_iter)
                diff_n_iter += 1

        epoch_def_loss += step_def_loss
        epoch_total_loss += step_total_loss
        epoch_pro_accs += pro_accs
        epoch_ter_accs += ter_accs
        epoch_arg_accs[0] += arg_accs[0]
        epoch_arg_accs[1] += arg_accs[1]
        epoch_step += len(x)

    epoch_time = time.time() - start_time
    # Guard the epoch averages against empty data.
    total_steps = max(epoch_step, 1)
    mean_def_loss = epoch_def_loss / total_steps
    mean_total_loss = epoch_total_loss / total_steps
    mean_ter_acc = epoch_ter_accs / total_steps
    mean_pro_acc = epoch_pro_accs / total_steps

    print(("Mode: {0:s} For whole Epoch {1:02d}, Time Consum {2:05f} Default Step Loss {3:05f}, "
           "Total Step Loss {4:05f}, Term Acc: {5:03f}, Prog Acc: {6:03f}, Arg0 Acc: {7:03f}, Arg1 Acc: {8:03f}")
          .format(mode, epoch, epoch_time, mean_def_loss,
                  mean_total_loss, mean_ter_acc, mean_pro_acc,
                  epoch_arg_accs[0] / total_steps,
                  epoch_arg_accs[1] / total_steps))
    print('===============================')
    return (mean_def_loss, mean_total_loss, mean_ter_acc, mean_pro_acc)
Esempio n. 3
0
class Trace:
    """Record the sequence of program calls made while processing an expression.

    Every entry appended to ``self.trace`` has the form
    ``((program, P[program]), args, flag)``. ``flag`` is ``True`` only for the
    final entry emitted by :meth:`pop_all`; every other entry carries
    ``False``. All state changes are mirrored on ``self.scratch`` so the trace
    and the scratch pad stay in step.
    """

    def __init__(self, exp, debug=False):
        """Generate the trace for ``exp`` so it can later be printed or consumed.

        Args:
            exp: Expression handed unchanged to ``ScratchPad``.
            debug: Forwarded to ``ScratchPad`` as its ``debug`` keyword.
        """
        super().__init__()
        self.exp = exp
        self.trace = []
        self.scratch = ScratchPad(exp, debug=debug)
        self.build()

    def _record(self, program, args, flag=False):
        """Append one ``((program, P[program]), args, flag)`` trace entry."""
        self.trace.append(((program, P[program]), args, flag))

    def _emit_pop(self, s_top):
        """Record and perform the removal of stack top ``s_top``.

        Non-parenthesis symbols are written to the output first; then the
        stack cell is overwritten with 0, the stack pointer moves left and the
        scratch pad pops.
        """
        if I2A[s_top] not in ('(', ')'):
            self.write_out(s_top)

        self._record(WRITE, [STACK_PTR, 0])
        self._record(MOV_PTR, [STACK_PTR, LEFT])

        self.scratch.pop()

    def build(self):
        """Emit the entry call, then alternate precedence/next until done."""
        self._record(REVPOLI, [])

        while not self.scratch.done():
            self.precedence()
            self.next()

    def precedence(self):
        """Record a precedence check and dispatch on the scratch pad's verdict.

        ``self.scratch.prece()`` returns ``(push_flg, value)``:
        0 -> push ``value``; 1 -> pop once, then push ``value``;
        2 -> write ``value`` to the output; 3 -> pop until ``'('`` is removed.
        Any other flag is ignored.
        """
        self._record(PRECE, [])
        push_flg, value = self.scratch.prece()
        if push_flg == 0:
            self.push(value)
        elif push_flg == 1:
            self.pop()
            self.push(value)
        elif push_flg == 2:
            self.write_out(value)
        elif push_flg == 3:
            self.pop_loop()

    def write_out(self, value):
        """Record writing ``value`` to the output and advancing the out pointer."""
        self._record(WRITE, [WRITE_OUT, value])
        self._record(MOV_PTR, [OUT_PTR, RIGHT])

        self.scratch.write_out(value)

    def push(self, value):
        """Record advancing the stack pointer and writing ``value`` on the stack."""
        self._record(MOV_PTR, [STACK_PTR, RIGHT])
        self._record(WRITE, [STACK_PTR, value])

        self.scratch.push(value)

    def pop(self):
        """Pop a single element from the stack."""
        self._emit_pop(self.scratch.pop_read())

    def pop_loop(self):
        """Pop repeatedly until an element that maps to ``'('`` has been removed.

        Written as a loop rather than recursion so that the depth of the stack
        cannot exhaust the interpreter's recursion limit.
        """
        while True:
            s_top = self.scratch.pop_read()
            self._emit_pop(s_top)
            if I2A[s_top] == '(':
                break

    def pop_all(self):
        """Pop every remaining element, then emit the final flagged entry.

        A stack-top value of 0 is treated as the end of the stack: the last
        entry (stack pointer left, flag ``True``) is recorded and the scratch
        pad is not popped again. Iterative for the same reason as
        :meth:`pop_loop`.
        """
        while True:
            s_top = self.scratch.pop_read()
            if s_top == 0:
                self._record(MOV_PTR, [STACK_PTR, LEFT], True)
                return
            self._emit_pop(s_top)

    def next(self):
        """Record advancing the expression pointer; drain the stack at the end."""
        self._record(MOV_PTR, [EXP_PTR, RIGHT])
        self.scratch.next()
        if self.scratch.done():
            self.pop_all()