def tested3():
    """Run the pre-trained network over inline 750 and save the predicted vp/vs section.

    Every trace of the inline is split into ``number`` windows of ``BATCH_LEN``
    samples. Each window is passed through ``BSsequential_net_seismic`` and the
    prediction is written into the matching row/column window of the output
    array, which is finally saved to a ``.mat`` file.

    Relies on module-level globals: ``seismic``, ``xline_num``, ``number``,
    ``BATCH_LEN``, ``BATCH_SIZE``, ``device`` and ``date_num``.
    """
    Inline750_vp_vs = np.zeros((xline_num, seismic.shape[1]))
    # Inline 750 occupies the first block of xline_num rows of ``seismic``.
    Inline750_seismic = seismic[(750-750)*xline_num:(751-750)*xline_num, :]

    model = BSsequential_net_seismic(BATCH_LEN).to(device)
    mode_patch = './model_file/pre_trained_network_model_model1/pre_trained_network_model_LN_2020_08_05_0%d.pth' % date_num
    model.load_state_dict(torch.load(mode_patch))
    # A freshly constructed module is in training mode; switch to inference mode
    # so layers such as dropout / batch-norm (if the network has any) are deterministic.
    model.eval()

    with torch.no_grad():  # inference only: no autograd graph is needed
        for trace_number in range(0, xline_num):
            print(trace_number)

            temp_seismic = torch.from_numpy(Inline750_seismic[trace_number, :]).float().view(1, -1)

            for num_batchlen in range(0, number):
                col_start = num_batchlen * BATCH_LEN
                col_end = (num_batchlen + 1) * BATCH_LEN
                window = temp_seismic[:, col_start:col_end]
                # The same window serves as input and (unused) label.
                train_dataset = MyDataset2(window, window)
                # One-sample dataset: shuffling is meaningless and a worker
                # process per window would only add start-up overhead.
                train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=False,
                                              drop_last=False)

                for train_dt, _ in train_dataloader:
                    train_dt = train_dt.to(device).float()

                    output = model(train_dt, BATCH_LEN)
                    np_output = output.cpu().detach().numpy()
                    # The dataset holds exactly this one trace, so its prediction
                    # belongs in row ``trace_number``. The previous offset
                    # ``trace_number * BATCH_SIZE`` was only correct for BATCH_SIZE == 1.
                    Inline750_vp_vs[trace_number:trace_number + 1, col_start:col_end] = np_output

    pathmat = './LN_out/Inline750_vp_vs_model1_2020_08_05_0%d.mat' % date_num
    scipio.savemat(pathmat, {'Inline750_vp_vs_model1_2020_08_05_0%d' % date_num: Inline750_vp_vs})
def tested2():
    """Predict impedance for the training-well seismic traces and save it to a ``.mat`` file.

    The traces in ``train_well_seismic`` are processed in ``number`` column
    windows of ``BATCH_LEN`` samples. Batch ``itr`` of each window is written
    back to rows ``itr*BATCH_SIZE .. (itr+1)*BATCH_SIZE`` of the result.

    Relies on module-level globals: ``train_well_seismic``, ``number``,
    ``BATCH_LEN``, ``BATCH_SIZE`` and ``device``.
    """
    impedance_well = np.zeros(
        (train_well_seismic.shape[0], train_well_seismic.shape[1]))

    # The checkpoint is the same for every window, so load it once.
    model = ConvNet2_2().to(device)
    mode_patch = './model_file/pre_trained_network_model_model2/pre_trained_network_model_SMI_2020_04_17_01.pth'
    model.load_state_dict(torch.load(mode_patch))
    # Inference mode: disable training-time behaviour (dropout / batch-norm, if any).
    model.eval()

    with torch.no_grad():
        for num in range(0, number):
            cols = slice(num * BATCH_LEN, (num + 1) * BATCH_LEN)

            test_dataset = MyDataset2(train_well_seismic[:, cols],
                                      train_well_seismic[:, cols])
            # shuffle must stay False: batch ``itr`` is written to rows
            # itr*BATCH_SIZE.., which is only correct when batches arrive in
            # dataset order. With shuffle=True predictions landed on wrong wells.
            test_dataloader = DataLoader(test_dataset,
                                         batch_size=BATCH_SIZE,
                                         num_workers=5,
                                         shuffle=False,
                                         drop_last=False)

            for itr, (test_dt, _) in enumerate(test_dataloader):
                test_dt = test_dt.to(device).float()
                output = model(test_dt)
                np_output = output.cpu().detach().numpy()
                impedance_well[(itr * BATCH_SIZE):((itr + 1) * BATCH_SIZE), cols] = np_output

    pathmat = './SMI_out/test_Impedance_model2_2020_04_17_01.mat'
    scipio.savemat(pathmat,
                   {'test_Impedance_model2_2020_04_17_01': impedance_well})
def tested4():
    """Predict impedance for every trace of ``test_Inline99_seismic`` and save it.

    Each trace is split into ``number`` windows of ``BATCH_LEN`` samples. Each
    window is passed through ``BSsequential_net_seismic`` and the prediction is
    stored in the matching row/column window of the result, which is saved to
    a ``.mat`` file.

    Relies on module-level globals: ``test_Inline99_seismic``, ``number``,
    ``BATCH_LEN``, ``BATCH_SIZE``, ``device`` and ``date_num``.
    """
    impedance_inline99 = np.zeros(
        (test_Inline99_seismic.shape[0], test_Inline99_seismic.shape[1]))

    model = BSsequential_net_seismic().to(device)
    mode_patch = './model_file/pre_trained_network_model_model1/pre_trained_network_model_SMI_2020_08_03_0%d.pth' % date_num
    model.load_state_dict(torch.load(mode_patch))
    # Inference mode: disable training-time behaviour (dropout / batch-norm, if any).
    model.eval()

    with torch.no_grad():
        for trace_number in range(0, test_Inline99_seismic.shape[0]):
            print(trace_number)

            temp_lable = torch.from_numpy(test_Inline99_seismic[trace_number, :])
            temp_lable = temp_lable.float().view(1, -1)
            for num_batchlen in range(0, number):
                col_start = num_batchlen * BATCH_LEN
                col_end = (num_batchlen + 1) * BATCH_LEN
                window = temp_lable[:, col_start:col_end]
                # The target window is both input and (unused) label.
                train_dataset = MyDataset2(window, window)
                # One-sample dataset: no shuffling needed, and no worker process per window.
                train_dataloader = DataLoader(train_dataset,
                                              batch_size=BATCH_SIZE,
                                              num_workers=0,
                                              shuffle=False,
                                              drop_last=False)

                for train_dt, _ in train_dataloader:
                    train_dt = train_dt.to(device).float()

                    output = model(train_dt)
                    np_output = output.cpu().detach().numpy()
                    # The dataset holds only this trace, so its prediction goes
                    # to row ``trace_number``. The former ``trace_number * BATCH_SIZE``
                    # offset was only correct for BATCH_SIZE == 1.
                    impedance_inline99[trace_number:trace_number + 1, col_start:col_end] = np_output

    pathmat = './SMI_out/pre_In99_Impedance_model1_2020_08_03_0%d.mat' % date_num
    scipio.savemat(pathmat, {
        'pre_In99_Impedance_model1_2020_08_03_0%d' % date_num:
        impedance_inline99
    })
def pre_trained(judge):
    """Pre-train ``ConvNet1_3`` on random traces, using weighted nearby wells as input.

    Each epoch does the following:

    1. Draw one random trace index.
    2. Select the wells whose seismic has |correlation| with that trace above
       ``coefval`` and that lie within radius ``Rval``. If none qualify, fall
       back to all 104 wells ranked by |correlation|.
    3. Min-max scale the selected correlations and normalise them to sum to 1.
    4. Train on one random ``BATCH_LEN`` window.

    The checkpoint is saved whenever the epoch loss beats the best loss so far.

    Args:
        judge: 0 starts from a freshly constructed model. Any other value
            resumes from the saved checkpoint, best loss and epoch counter.

    Relies on module-level globals such as ``device``, ``lr``, ``EPOCHS``,
    ``BATCH_LEN``, ``BATCH_SIZE``, ``data_rate``, ``coefval``, ``Rval`` and the
    seismic / well / label arrays.
    """
    writer = SummaryWriter(
        log_dir='./loss/pre_train_loss_model1/pre_train_loss_SMI_2020_04_23_01'
    )

    model = ConvNet1_3().to(device)
    if judge == 0:
        temp = 10000000000000  # best loss so far (sentinel before the first epoch)
        epoch_num = 1
    else:
        mode_patch = './model_file/pre_trained_network_model_model1/pre_trained_network_model_SMI_2020_04_23_01.pth'
        model.load_state_dict(torch.load(mode_patch))
        path_temp = './Temporary_parameters/pre_temp_model1.mat'
        temp = scipio.loadmat(path_temp)['temp'].item()
        path_epoch = './Temporary_parameters/pre_epoch_num_model1.mat'
        epoch_num = scipio.loadmat(path_epoch)['epoch_num'].item() + 1

    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_dt = None  # last training batch, reused for add_graph after the loop
    for epoch in range(epoch_num, EPOCHS + 1):
        print(epoch)
        trace_number = np.random.randint(0, 142 * 110 * data_rate, 1)

        # Correlation coefficients: row 0 is the target trace, rows 1..104 the well-side traces.
        coef_seismic = np.zeros((105, Xline1_110_label_impedance.shape[1]))
        coef_seismic[0, :] = train1_110_seismic[trace_number, :]
        coef_seismic[1:105, :] = train_well_seismic[:, :]
        temp_coef = np.corrcoef(coef_seismic)
        absCORcoef = np.abs(temp_coef[0, 1:105])

        # Grid coordinates of the target trace (independent of the well, so computed once).
        seismicinline = np.mod(trace_number + 1, 142)
        seismicxline = (trace_number + 1 - seismicinline) / 142 + 1

        # Keep wells whose |correlation| exceeds the threshold and that lie within the radius.
        selected = []
        for k in range(0, 104):
            if absCORcoef[k] > coefval:
                wellxline = Xline_Inline_number[0, k]
                wellinline = Xline_Inline_number[1, k]
                R = np.sqrt((seismicxline - wellxline) *
                            (seismicxline - wellxline) +
                            (seismicinline - wellinline) *
                            (seismicinline - wellinline))
                if R < Rval:
                    selected.append(k)

        if selected:
            well_idx = np.array(selected)
        else:
            # No well qualified: use all wells ordered by descending |correlation|.
            # A stable argsort replaces the old repeated-max loop, which appended
            # a well several times when correlations tied (breaking the reshape below).
            well_idx = np.argsort(-absCORcoef, kind='stable')
        num = len(well_idx)
        tempval = absCORcoef[well_idx]
        temp_train_well = train_well[well_idx, :]

        # Turn the correlations into weights that sum to 1.
        if num > 1:
            maxval = max(tempval)
            minval = min(tempval)
            max_minlen = maxval - minval
            if max_minlen > 0:
                tempval = (tempval - minval) / max_minlen
            else:
                # All correlations equal: min-max scaling would divide by zero.
                tempval = np.ones(num)
        else:
            # A single well takes the whole weight. The former scalar ``1`` made
            # ``sum`` / ``torch.from_numpy`` below raise TypeError.
            tempval = np.ones(1)
        valsum = sum(tempval)
        tempval = tempval / valsum
        tempval = torch.from_numpy(tempval).view(1, -1).float().to(device)

        temp_train_well = torch.from_numpy(temp_train_well).view(num, -1).float().to(device)

        temp_lable = torch.from_numpy(
            Xline1_110_label_impedance[trace_number, :])
        temp_lable = temp_lable.float().to(device).view(1, -1)

        # Network input: the selected wells over a random BATCH_LEN window.
        # Label: the target trace's impedance over the same window.
        rand = np.random.randint(0, 60 - BATCH_LEN + 1, 1)
        train_dataset = MyDataset2(
            temp_train_well[:, rand[0]:rand[0] + BATCH_LEN],
            temp_lable[:, rand[0]:rand[0] + BATCH_LEN])
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=BATCH_SIZE,
                                      num_workers=1,
                                      shuffle=True,
                                      drop_last=False)
        epoch_loss = []

        for train_dt, train_lable in train_dataloader:
            train_dt, train_lable = train_dt.to(device), train_lable.to(device)
            train_dt = train_dt.float()
            train_lable = train_lable.float()

            model.train()
            optimizer.zero_grad()
            output = model(train_dt, tempval)
            loss = F.mse_loss(output, train_lable)
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())

        epoch_loss = np.sum(np.array(epoch_loss))
        writer.add_scalar('Train/MSE', epoch_loss, epoch)
        epoch_num = epoch
        print('Train set: Average loss: {:.15f}'.format(epoch_loss))
        if epoch_loss < temp:
            path = './model_file/pre_trained_network_model_model1/pre_trained_network_model_SMI_2020_04_23_01.pth'
            torch.save(model.state_dict(), path)
            temp = epoch_loss
        # Persist best loss and epoch so a run with judge != 0 can resume.
        path_temp = './Temporary_parameters/pre_temp_model1.mat'
        path_epoch = './Temporary_parameters/pre_epoch_num_model1.mat'
        scipio.savemat(path_temp, {'temp': temp})
        scipio.savemat(path_epoch, {'epoch_num': epoch_num})

    if train_dt is not None:  # no batch exists when no epoch ran
        writer.add_graph(model, (train_dt, tempval))
    writer.close()
def tested2():
    """Predict impedance along crossline 76 with the pre-trained ``ConvNet1_3`` and save it.

    For every ``BATCH_LEN`` window and every trace of ``test_Xline76_seismic``:

    1. Select wells whose seismic has |correlation| with the trace above
       ``coefval`` and that lie within radius ``Rval``. If none qualify, use
       all 104 wells ranked by |correlation|.
    2. Turn the correlations into normalised weights.
    3. Apply the model and store the prediction.

    The result is saved to a ``.mat`` file.
    """
    impedance_xline76 = np.zeros(
        (test_Xline76_seismic.shape[0], test_Xline76_seismic.shape[1]))

    model = ConvNet1_3().to(device)
    mode_patch = './model_file/pre_trained_network_model_model1/pre_trained_network_model_SMI_2020_04_23_01.pth'
    model.load_state_dict(torch.load(mode_patch))
    # Inference mode: disable training-time behaviour (dropout / batch-norm, if any).
    model.eval()

    with torch.no_grad():
        for num_batchlen in range(0, number):
            col_start = num_batchlen * BATCH_LEN
            col_end = (num_batchlen + 1) * BATCH_LEN
            for trace_number in range(0, test_Xline76_seismic.shape[0]):
                print(trace_number)

                # Correlation coefficients: row 0 is the target trace, rows 1..104 the well-side traces.
                coef_seismic = np.zeros((105, test_Xline76_seismic.shape[1]))
                coef_seismic[0, :] = test_Xline76_seismic[trace_number, :]
                coef_seismic[1:105, :] = train_well_seismic[:, :]
                temp_coef = np.corrcoef(coef_seismic)
                absCORcoef = np.abs(temp_coef[0, 1:105])

                # Grid coordinates of the target trace (independent of the well).
                temp_trace_number = trace_number * 110 + 76
                seismicxline = np.mod(temp_trace_number, 110)
                seismicinline = (temp_trace_number -
                                 seismicxline) / 110 + 1

                # Keep wells with |correlation| above the threshold that lie within the radius.
                selected = []
                for k in range(0, 104):
                    if absCORcoef[k] > coefval:
                        wellxline = Xline_Inline_number[0, k]
                        wellinline = Xline_Inline_number[1, k]
                        R = np.sqrt((seismicxline - wellxline) *
                                    (seismicxline - wellxline) +
                                    (seismicinline - wellinline) *
                                    (seismicinline - wellinline))
                        if R < Rval:
                            selected.append(k)

                if selected:
                    well_idx = np.array(selected)
                else:
                    # No well qualified: all wells, by descending |correlation|.
                    # Stable argsort avoids the duplicate rows the old repeated-max
                    # loop appended on ties.
                    well_idx = np.argsort(-absCORcoef, kind='stable')
                num = len(well_idx)
                tempval = absCORcoef[well_idx]
                temp_train_well = train_well[well_idx, :]

                # Turn the correlations into weights that sum to 1.
                if num > 1:
                    maxval = max(tempval)
                    minval = min(tempval)
                    max_minlen = maxval - minval
                    if max_minlen > 0:
                        tempval = (tempval - minval) / max_minlen
                    else:
                        # All correlations equal: scaling would divide by zero.
                        tempval = np.ones(num)
                else:
                    # A single well takes the whole weight. The former scalar ``1``
                    # made ``sum`` / ``torch.from_numpy`` raise TypeError.
                    tempval = np.ones(1)
                valsum = sum(tempval)
                tempval = tempval / valsum
                tempval = torch.from_numpy(tempval).view(1, -1).float().to(device)

                temp_train_well = torch.from_numpy(temp_train_well).view(num, -1).float().to(device)

                temp_lable = torch.from_numpy(
                    test_Xline76_seismic[trace_number, :])
                temp_lable = temp_lable.float().to(device).view(1, -1)
                # Network input: the selected wells over this window; the target trace is the label.
                train_dataset = MyDataset2(
                    temp_train_well[:, col_start:col_end],
                    temp_lable[:, col_start:col_end])
                train_dataloader = DataLoader(train_dataset,
                                              batch_size=BATCH_SIZE,
                                              num_workers=1,
                                              shuffle=True,
                                              drop_last=False)

                for train_dt, _ in train_dataloader:
                    train_dt = train_dt.to(device).float()

                    output = model(train_dt, tempval)
                    np_output = output.cpu().detach().numpy()
                    # NOTE(review): rows from ``trace_number * BATCH_SIZE`` are this
                    # trace's own row only when BATCH_SIZE == 1. Confirm the
                    # dataset length / model output shape before changing it.
                    impedance_xline76[(trace_number *
                                       BATCH_SIZE):((trace_number + 1) *
                                                    BATCH_SIZE),
                                      col_start:col_end] = np_output

    pathmat = './SMI_out/pre_X76_Impedance_model1_2020_04_23_01.mat'
    scipio.savemat(
        pathmat, {'pre_X76_Impedance_model1_2020_04_23_01': impedance_xline76})
def pre_trained(judge):
    """Pre-train the seismic-to-vp/vs network on one trace per epoch.

    Each epoch proceeds as follows:

    1. Pick a trace, either at random or replayed from a saved map when
       ``is_consistent == 1``.
    2. Train on ``number`` random ``BATCH_LEN`` windows of that trace. The
       seismic window is the input and the vp/vs window is the label.
    3. Log the summed loss and save the checkpoint.

    Training stops early once the summed ``fc60`` weights have stayed unchanged
    for more than 50 consecutive epochs.

    Args:
        judge: 0 builds a fresh ``BSsequential_net_seismic``. Any other value
            loads a saved ``BSsequential_net`` checkpoint.
    """
    writer = SummaryWriter(log_dir='./loss/pre_train_loss_model1/pre_train_loss_LN_2020_08_05_0%d' % date_num)

    if judge == 0:
        model = BSsequential_net_seismic(BATCH_LEN).to(device)
        print("Total number of paramerters in networks is {}  ".format(sum(x.numel() for x in model.parameters())))
        device1 = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
        print(device1)
        temp = 10000000000000
        epoch_num = 1
    else:
        # NOTE(review): this branch builds ``BSsequential_net`` while the fresh start
        # builds ``BSsequential_net_seismic``. Confirm the checkpoint matches.
        model = BSsequential_net(BATCH_LEN).to(device)
        mode_patch = './model_file/pre_trained_network_model_model1/pre_trained_network_model_LN_2020_08_04_02.pth'
        model.load_state_dict(torch.load(mode_patch))
        temp = 10000000000000
        epoch_num = 1
    if is_consistent == 0:
        # Record the drawn trace positions; they are saved after training.
        map_xline = np.zeros(0)
        map_inline = np.zeros(0)
    else:
        # Replay a previously recorded sequence of trace positions.
        is_path = './SMI_out/map_number_2020_05_14_06.mat'
        Random_path = scipio.loadmat(is_path)
        map_xline = Random_path['map_xline']
        map_inline = Random_path['map_inline']
    count = 0  # consecutive epochs in which the monitored network weights did not change
    lr = 0.001  # learning rate
    # Build the optimizer once so Adam's moment estimates persist across epochs.
    # The decay below edits the learning rate in place; rebuilding the optimizer
    # every epoch silently reset that state.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_dt = None  # last training batch, reused for add_graph after the loop
    for epoch in range(epoch_num, EPOCHS+1):
        print(epoch, count)
        # Sum of fc60 weights before this epoch, used to detect whether training still changes them.
        temp_a = torch.sum(model.fc60.weight.data)

        if np.mod(epoch + 1, 200) == 0:
            lr = lr * 0.99
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if is_consistent == 1:
            # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is what it aliased.
            trace_number = int(map_inline[0, epoch-1]*xline_num+map_xline[0, epoch-1])
        else:
            # NOTE(review): the 501 x 631 grid is hard-coded here, while the replay
            # branch uses ``xline_num``. Confirm xline_num == 631.
            temp_1 = np.random.randint(0, 501, 1)
            temp_2 = np.random.randint(0, 631, 1)
            trace_number = temp_1*631+temp_2
            map_inline = np.append(map_inline, temp_1)
            map_xline = np.append(map_xline, temp_2)

        temp_train_seismic = torch.from_numpy(seismic[trace_number, :]).float().view(1, -1)
        temp_lable = torch.from_numpy(vp_vs_lable[trace_number, :]).float().view(1, -1)
        # Collect losses over all windows of the epoch. Previously this was reset
        # per window, so only the last window was logged.
        epoch_loss = []
        for num_rand in range(0, number):
            # Random BATCH_LEN window: seismic as input, vp/vs as label.
            rand = np.random.randint(0, seismic.shape[1] - BATCH_LEN + 1, 1)
            train_dataset = MyDataset2(temp_train_seismic[:, rand[0]:rand[0] + BATCH_LEN], temp_lable[:, rand[0]:rand[0] + BATCH_LEN])
            train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=1, shuffle=True, drop_last=False)

            for train_dt, train_lable in train_dataloader:
                train_dt, train_lable = train_dt.to(device), train_lable.to(device)
                train_dt = train_dt.float()
                train_lable = train_lable.float()

                model.train()
                optimizer.zero_grad()
                output = model(train_dt, BATCH_LEN)
                if is_synseismic == 1:
                    syn_seismic = syn_seismic_fun2(output, wavelet)
                    syn_seismic = syn_seismic.float()
                    # Compare the synthetic with this batch's seismic window, which is
                    # already on ``device``. The full CPU trace used before mismatched
                    # both length and device.
                    loss = F.mse_loss(syn_seismic, train_dt) + F.mse_loss(output, train_lable)
                else:
                    loss = F.mse_loss(output, train_lable)

                loss.backward()
                optimizer.step()

                epoch_loss.append(loss.item())

        temp_b = torch.sum(model.fc60.weight.data)
        if temp_a == temp_b:
            count = count + 1
        else:
            count = 0
        if count > 50:
            break

        epoch_loss = np.sum(np.array(epoch_loss))
        writer.add_scalar('Train/MSE', epoch_loss, epoch)
        epoch_num = epoch
        print('Train set: Average loss: {:.15f}'.format(epoch_loss))
        # NOTE: ``temp`` is never lowered, so the checkpoint is (re)saved every epoch.
        if epoch_loss < temp:
            path = './model_file/pre_trained_network_model_model1/pre_trained_network_model_LN_2020_08_05_0%d.pth' % date_num
            torch.save(model.state_dict(), path)
        path_loss = './Temporary_parameters/pre_temp_model1.mat'
        path_epoch = './Temporary_parameters/pre_epoch_num_model1.mat'
        scipio.savemat(path_loss, {'epoch_loss': epoch_loss})
        scipio.savemat(path_epoch, {'epoch_num': epoch_num})
    if is_consistent == 0:
        pathmat = './LN_out/map_number_2020_08_05_0%d.mat' % date_num
        scipio.savemat(pathmat, {'map_xline': map_xline, 'map_inline': map_inline})
    if train_dt is not None:  # no batch exists when no epoch ran
        writer.add_graph(model, (train_dt, torch.tensor(BATCH_LEN)))
    writer.close()
def pre_trained(judge):
    """Pre-train ``BSsequential_net`` on windows of all rows of ``seisic_well_number``.

    Each epoch decays the learning rate by 0.9 and then trains on ``number``
    random ``BATCH_LEN`` windows. In every row of ``seisic_well_number``:

    - the window columns form the seismic input;
    - the last ``num_well`` columns are well indices, whose ``train_well``
      windows are also fed to the network.

    The label is the same window of ``vp_vs_lable``. Loss is logged and the
    checkpoint saved after every optimizer step. Training stops entirely once
    the summed ``fc60`` weights have not changed for more than 100
    consecutive steps.

    Args:
        judge: 0 starts from a fresh model. Any other value loads the saved
            checkpoint and starts counting epochs at 50001.
    """
    writer = SummaryWriter(
        log_dir='./loss/pre_train_loss_model1/pre_train_loss_LN_2020_08_08_0%d'
        % date_num)

    model = BSsequential_net(BATCH_LEN, num_well).to(device)
    if judge == 0:
        print("Total number of paramerters in networks is {}  ".format(
            sum(x.numel() for x in model.parameters())))
        epoch_num = 1
    else:
        mode_patch = './model_file/pre_trained_network_model_model1/pre_trained_network_model_LN_2020_08_07_0%d.pth' % date_num
        model.load_state_dict(torch.load(mode_patch))
        epoch_num = 50001
    device1 = torch.device("cuda:0" if (
        torch.cuda.is_available()) else "cpu")
    print(device1)
    # NOTE(review): map_xline / map_inline are not used further in this function.
    if is_consistent == 0:
        map_xline = np.zeros(0)
        map_inline = np.zeros(0)
    else:
        is_path = './SMI_out/map_number_2020_05_14_06.mat'
        Random_path = scipio.loadmat(is_path)
        map_xline = Random_path['map_xline']
        map_inline = Random_path['map_inline']
    count = 0  # consecutive steps in which the monitored network weights did not change
    lr = 0.001  # learning rate
    temp_epoch = 1  # global step counter for TensorBoard
    # Build the optimizer once so Adam's moment estimates persist across epochs.
    # The per-epoch decay edits the learning rate in place.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    stalled = False  # set when the weights stop changing; ends all loops
    for epoch in range(epoch_num, EPOCHS + 1):
        # Sum of fc60 weights at the start of the epoch, used to detect a stalled network.
        temp_a = torch.sum(model.fc60.weight.data)

        lr = lr * 0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for num_rand in range(0, number):
            rand = np.random.randint(0, train_well.shape[1] - BATCH_LEN + 1, 1)
            window = slice(rand[0], rand[0] + BATCH_LEN)

            # Each row: the seismic window followed by that row's num_well well indices.
            n_cols = seisic_well_number.shape[1]
            train_seisic_well_number = np.zeros(
                (totall_number, BATCH_LEN + num_well))
            train_seisic_well_number[:, 0:BATCH_LEN] = seisic_well_number[:, window]
            train_seisic_well_number[:, BATCH_LEN:BATCH_LEN + num_well] = \
                seisic_well_number[:, n_cols - num_well:n_cols]
            train_seisic_well_number = torch.from_numpy(
                train_seisic_well_number).float()

            train_vp_vs_lable = torch.from_numpy(
                vp_vs_lable[:, window]).float()

            train_dataset = MyDataset2(train_seisic_well_number,
                                       train_vp_vs_lable)
            train_dataloader = DataLoader(train_dataset,
                                          batch_size=BATCH_SIZE,
                                          num_workers=BATCH_SIZE,
                                          shuffle=True,
                                          drop_last=False)

            for itr, (train_dt, train_lable) in enumerate(train_dataloader):
                # Gather the current window of every well referenced by the batch.
                # Convert the index tensor to NumPy explicitly instead of indexing
                # a NumPy array with a torch tensor.
                well_index = train_dt[:, BATCH_LEN:train_dt.shape[1]].int().numpy()
                temp_train_well = torch.from_numpy(train_well[well_index, window])
                temp_train_well = temp_train_well.to(device).float()
                train_dt = train_dt[:, 0:BATCH_LEN].to(device).float()
                train_lable = train_lable.to(device).float()

                model.train()
                optimizer.zero_grad()
                output = model(temp_train_well, train_dt)
                if is_synseismic == 1:
                    syn_seismic = syn_seismic_fun2(output, wavelet)
                    syn_seismic = syn_seismic.float()
                    # ``temp_train_seismic`` is not defined in this function. The
                    # seismic window of this batch is ``train_dt``.
                    loss = F.mse_loss(syn_seismic, train_dt) + F.mse_loss(
                        output, train_lable)
                else:
                    loss = F.mse_loss(output, train_lable) + F.l1_loss(
                        output, train_lable)

                loss.backward()
                optimizer.step()

                temp_b = torch.sum(model.fc60.weight.data)
                if temp_a == temp_b:
                    count = count + 1
                else:
                    count = 0
                if count > 100:
                    # Weights stopped changing: stop training entirely. A bare
                    # ``break`` here only left the batch loop.
                    stalled = True
                    break

                epoch_loss = loss.item()
                writer.add_scalar('Train/MSE', epoch_loss, temp_epoch)
                temp_epoch = temp_epoch + 1
                print("epoch-itr-count:", epoch, itr, count,
                      'Train set: Average loss: {:.15f}'.format(epoch_loss))
                path = './model_file/pre_trained_network_model_model1/pre_trained_network_model_LN_2020_08_08_0%d.pth' % date_num
                torch.save(model.state_dict(), path)
            if stalled:
                break
        if stalled:
            break
    writer.close()
def tested4():
    """Predict impedance for every trace of the inline-99 test section and save it to a .mat file.

    For each target trace of ``test_Inline99_seismic`` the function
      1. correlates the trace with the 104 well-side seismic traces in ``train_well_seismic``;
      2. picks ``num_well`` wells: when ``which_choose_well == 1`` it prefers wells whose
         |correlation| exceeds ``coefval`` and that lie closer than ``Rval`` to the trace,
         falling back to the most-correlated wells overall if fewer than ``num_well`` qualify;
         otherwise it simply takes the ``num_well`` most-correlated wells;
      3. feeds the selected well logs and the target seismic, ``number`` windows of
         ``BATCH_LEN`` samples at a time, through the pre-trained ``BSsequential_net_lstm``;
      4. writes the network output into ``impedance_inline99``.

    The result is saved to
    ``./SMI_out/pre_In99_Impedance_model1_2020_07_31_0<date_num>.mat``.
    All inputs (data arrays, hyper-parameters, ``device``, ``date_num``) are module globals.
    """
    impedance_inline99 = np.zeros((test_Inline99_seismic.shape[0], test_Inline99_seismic.shape[1]))

    model = BSsequential_net_lstm().to(device)
    mode_patch = './model_file/pre_trained_network_model_model1/pre_trained_network_model_SMI_2020_07_31_0%d.pth' % date_num
    # map_location lets a checkpoint written on a GPU be restored on a CPU-only host.
    model.load_state_dict(torch.load(mode_patch, map_location=device))
    # A freshly constructed nn.Module is in training mode; switch to inference mode so any
    # dropout / batch-norm layers behave deterministically during prediction.
    model.eval()

    for trace_number in range(0, test_Inline99_seismic.shape[0]):
        print(trace_number)

        # Correlation between the target trace (row 0) and the 104 well-side traces (rows 1..104).
        coef_seismic = np.zeros((105, Xline1_110_label_impedance.shape[1]))
        coef_seismic[0, :] = test_Inline99_seismic[trace_number, :]
        coef_seismic[1:105, :] = train_well_seismic[:, :]
        temp_coef = np.corrcoef(coef_seismic)

        # Candidate wells: |correlation| above the threshold and inside the search radius.
        tempval_1 = np.zeros(0)
        temp_train_well_1 = np.zeros(0)
        temp_train_well_seisic_1 = np.zeros(0)
        absCORcoef = np.abs(temp_coef[0, 1:105])
        if which_choose_well == 1:
            # 1-based (xline, inline) position of the trace on a grid of 142 inlines per xline.
            # (The former mod(t + 1, 142) form mapped inline 142 to inline 0 of the next xline.)
            # NOTE(review): trace_number indexes rows of test_Inline99_seismic here; confirm that
            # this grid decoding is meaningful for the inline-99 section.
            seismicinline = np.mod(trace_number, 142) + 1
            seismicxline = trace_number // 142 + 1
            num = 0
            for k in range(0, 104):
                if absCORcoef[k] > coefval:
                    # Grid position of well k.
                    wellxline = Xline_Inline_number[0, k]
                    wellinline = Xline_Inline_number[1, k]
                    R = np.sqrt((seismicxline - wellxline) * (seismicxline - wellxline)
                                + (seismicinline - wellinline) * (seismicinline - wellinline))
                    if R < Rval:
                        tempval_1 = np.append(tempval_1, absCORcoef[k])
                        temp_train_well_1 = np.append(temp_train_well_1, train_well[k, :])
                        temp_train_well_seisic_1 = np.append(temp_train_well_seisic_1, train_well_seismic[k, :])
                        num = num + 1

            temp_train_well = np.zeros(0)
            temp_train_well_seisic = np.zeros(0)
            if num < num_well:
                # Too few candidates: fall back to the num_well most-correlated wells overall.
                num = num_well
                tempval = np.zeros(0)
                for _ in range(0, num):
                    temp_tempval = max(absCORcoef)
                    tempval = np.append(tempval, temp_tempval)
                    for max_num2 in range(0, 104):
                        if temp_tempval == absCORcoef[max_num2]:
                            # Zero the picked entry so the next max() finds the next-best well.
                            # NOTE(review): exact ties would append several wells in one pass.
                            absCORcoef[max_num2] = 0
                            temp_train_well = np.append(temp_train_well, train_well[max_num2, :])
                            temp_train_well_seisic = np.append(temp_train_well_seisic, train_well_seismic[max_num2, :])
            else:
                # Keep the num_well most-correlated wells among the candidates.
                tempval = np.zeros(0)
                temp_train_well_1 = temp_train_well_1.reshape(num, -1)
                temp_train_well_seisic_1 = temp_train_well_seisic_1.reshape(num, -1)
                for _ in range(0, num_well):
                    temp_tempval = max(tempval_1)
                    tempval = np.append(tempval, temp_tempval)
                    for max_num2 in range(0, num):
                        if temp_tempval == tempval_1[max_num2]:
                            tempval_1[max_num2] = 0
                            temp_train_well = np.append(temp_train_well, temp_train_well_1[max_num2, :])
                            temp_train_well_seisic = np.append(temp_train_well_seisic,
                                                               temp_train_well_seisic_1[max_num2, :])
        else:
            # Correlation-only selection: the num_well most-correlated wells.
            num = num_well
            tempval = np.zeros(0)
            temp_train_well = np.zeros(0)
            temp_train_well_seisic = np.zeros(0)
            for _ in range(0, num):
                temp_tempval = max(absCORcoef)
                tempval = np.append(tempval, temp_tempval)
                for max_num2 in range(0, 104):
                    if temp_tempval == absCORcoef[max_num2]:
                        absCORcoef[max_num2] = 0
                        temp_train_well = np.append(temp_train_well, train_well[max_num2, :])
                        temp_train_well_seisic = np.append(temp_train_well_seisic, train_well_seismic[max_num2, :])

        num = num_well
        # Min-max normalised correlations of the selected wells, rescaled to sum to one.
        # NOTE(review): these weights are not consumed by the network call below.
        maxval = max(tempval)
        minval = min(tempval)
        max_minlen = maxval - minval
        tempval = (tempval - minval) / max_minlen
        valsum = sum(tempval)
        tempval = tempval / valsum

        tempval = torch.from_numpy(tempval)
        tempval = tempval.view(1, -1)
        tempval = tempval.float()
        tempval = tempval.to(device)

        # One row per selected well.
        temp_train_well = torch.from_numpy(temp_train_well)
        temp_train_well = temp_train_well.view(num, -1)
        temp_train_well = temp_train_well.float()

        temp_lable = torch.from_numpy(test_Inline99_seismic[trace_number, :])
        temp_lable = temp_lable.float()
        temp_lable = temp_lable.view(1, -1)
        for num_batchlen in range(0, number):
            window = slice(num_batchlen * BATCH_LEN, (num_batchlen + 1) * BATCH_LEN)
            # Network input: the selected well logs plus the target seismic window.
            train_dataset = MyDataset2(temp_train_well[:, window], temp_lable[:, window])
            train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=1, shuffle=True,
                                          drop_last=False)

            temp_seismic = torch.from_numpy(test_Inline99_seismic[trace_number, window])
            temp_seismic = temp_seismic.float()
            temp_seismic = temp_seismic.to(device)
            temp_seismic = temp_seismic.view(1, -1)

            for train_dt, train_lable in train_dataloader:
                train_dt, train_lable = train_dt.to(device), train_lable.to(device)
                train_dt = train_dt.float()

                # Pure inference: no autograd graph is needed.
                with torch.no_grad():
                    output = model(train_dt, temp_seismic)
                np_output = output.cpu().detach().numpy()
                impedance_inline99[(trace_number * BATCH_SIZE):((trace_number + 1) * BATCH_SIZE), window] = np_output

    pathmat = './SMI_out/pre_In99_Impedance_model1_2020_07_31_0%d.mat' % date_num
    scipio.savemat(pathmat, {'pre_In99_Impedance_model1_2020_07_31_0%d' % date_num: impedance_inline99})
def pre_trained(judge):
    """Pre-train ``BSsequential_net_lstm`` on traces drawn from the 142 x 110 training grid.

    NOTE(review): a later ``def pre_trained(judge)`` in this module rebinds this name, so this
    version is shadowed once the module has been fully executed.

    Each epoch
      * picks one training trace: at random (recording its grid position in
        ``map_xline``/``map_inline``), or from the positions saved in
        ``./SMI_out/map_number_2020_05_14_06.mat`` when ``is_consistent == 1``;
      * selects ``num_well`` wells for it by seismic correlation and distance;
      * trains on ``number`` random windows of ``BATCH_LEN`` samples with Adam, using an MSE
        loss (plus a synthetic-seismic term when ``is_synseismic == 1``).
    Training stops early once the summed LSTM weights have not changed for more than 50
    consecutive epochs. The learning rate starts at 0.001 and is multiplied by 0.99 whenever
    ``epoch + 1`` is a multiple of 200.

    Args:
        judge: 0 to start from a freshly initialised network; any other value first loads the
            checkpoint ``pre_trained_network_model_SMI_2020_07_24_02.pth``.
    """
    writer = SummaryWriter(log_dir='./loss/pre_train_loss_model1/pre_train_loss_SMI_2020_07_31_0%d' % date_num)

    model = BSsequential_net_lstm().to(device)
    if judge == 0:
        print("Total number of paramerters in networks is {}  ".format(sum(x.numel() for x in model.parameters())))
    else:
        mode_patch = './model_file/pre_trained_network_model_model1/pre_trained_network_model_SMI_2020_07_24_02.pth'
        # map_location lets a checkpoint written on a GPU be restored on a CPU-only host.
        model.load_state_dict(torch.load(mode_patch, map_location=device))
    temp = 10000000000000  # save threshold for the checkpoint (see NOTE at the save below)
    epoch_num = 1

    if is_consistent == 0:
        map_xline = np.zeros(0)
        map_inline = np.zeros(0)
    else:
        is_path = './SMI_out/map_number_2020_05_14_06.mat'
        Random_path = scipio.loadmat(is_path)
        map_xline = Random_path['map_xline']
        map_inline = Random_path['map_inline']
    count = 0  # consecutive epochs in which the monitored LSTM weights did not change
    lr = 0.001  # learning rate
    last_batch = None  # (train_dt, temp_train_seismic) of the latest step, used for add_graph
    for epoch in range(epoch_num, EPOCHS+1):
        print(epoch, count)
        # Snapshot of the summed LSTM weights, compared again at the end of the epoch.
        lstm = model.lstm60
        temp_a = torch.sum(lstm.weight_hh_l0.data) + torch.sum(lstm.weight_ih_l0.data)

        if np.mod(epoch + 1, 200) == 0:
            lr = lr * 0.99
        # NOTE(review): a new Adam instance every epoch discards its moment estimates; confirm intended.
        optimizer = optim.Adam(model.parameters(), lr=lr)
        if is_consistent == 1:
            # np.int was removed in NumPy 1.24; the builtin int is the exact equivalent.
            trace_number = int(map_xline[0, epoch-1]*142+map_inline[0, epoch-1])
        else:
            temp_1 = np.random.randint(0, 142, 1)  # inline index (stored in map_inline)
            temp_2 = np.random.randint(0, 110, 1)  # xline index (stored in map_xline)
            trace_number = temp_2*142+temp_1
            map_xline = np.append(map_xline, temp_2)
            map_inline = np.append(map_inline, temp_1)

        # Correlation between the target trace (row 0) and the 104 well-side traces (rows 1..104).
        coef_seismic = np.zeros((105, Xline1_110_label_impedance.shape[1]))
        coef_seismic[0, :] = train1_110_seismic[trace_number, :]
        coef_seismic[1:105, :] = train_well_seismic[:, :]
        temp_coef = np.corrcoef(coef_seismic)

        # Candidate wells: |correlation| above the threshold and inside the search radius.
        tempval_1 = np.zeros(0)
        temp_train_well_1 = np.zeros(0)
        temp_train_well_seisic_1 = np.zeros(0)
        absCORcoef = np.abs(temp_coef[0, 1:105])
        if which_choose_well == 1:
            # 1-based (xline, inline) position of the trace: trace_number = xline_idx*142 + inline_idx.
            # (The former mod(t + 1, 142) form mapped inline 142 to inline 0 of the next xline.)
            seismicinline = np.mod(trace_number, 142) + 1
            seismicxline = trace_number // 142 + 1
            num = 0
            for k in range(0, 104):
                if absCORcoef[k] > coefval:
                    # Grid position of well k.
                    wellxline = Xline_Inline_number[0, k]
                    wellinline = Xline_Inline_number[1, k]
                    R = np.sqrt((seismicxline - wellxline) * (seismicxline - wellxline)
                                + (seismicinline - wellinline) * (seismicinline - wellinline))
                    if R < Rval:
                        tempval_1 = np.append(tempval_1, absCORcoef[k])
                        temp_train_well_1 = np.append(temp_train_well_1, train_well[k, :])
                        temp_train_well_seisic_1 = np.append(temp_train_well_seisic_1, train_well_seismic[k, :])
                        num = num + 1

            temp_train_well = np.zeros(0)
            temp_train_well_seisic = np.zeros(0)
            if num < num_well:
                # Too few candidates: fall back to the num_well most-correlated wells overall.
                num = num_well
                tempval = np.zeros(0)
                for _ in range(0, num):
                    temp_tempval = max(absCORcoef)
                    tempval = np.append(tempval, temp_tempval)
                    for max_num2 in range(0, 104):
                        if temp_tempval == absCORcoef[max_num2]:
                            # Zero the picked entry so the next max() finds the next-best well.
                            absCORcoef[max_num2] = 0
                            temp_train_well = np.append(temp_train_well, train_well[max_num2, :])
                            temp_train_well_seisic = np.append(temp_train_well_seisic, train_well_seismic[max_num2, :])
            else:
                # Keep the num_well most-correlated wells among the candidates.
                tempval = np.zeros(0)
                temp_train_well_1 = temp_train_well_1.reshape(num, -1)
                temp_train_well_seisic_1 = temp_train_well_seisic_1.reshape(num, -1)
                for _ in range(0, num_well):
                    temp_tempval = max(tempval_1)
                    tempval = np.append(tempval, temp_tempval)
                    for max_num2 in range(0, num):
                        if temp_tempval == tempval_1[max_num2]:
                            tempval_1[max_num2] = 0
                            temp_train_well = np.append(temp_train_well, temp_train_well_1[max_num2, :])
                            temp_train_well_seisic = np.append(temp_train_well_seisic, temp_train_well_seisic_1[max_num2, :])
        else:
            # Correlation-only selection: the num_well most-correlated wells.
            num = num_well
            tempval = np.zeros(0)
            temp_train_well = np.zeros(0)
            temp_train_well_seisic = np.zeros(0)
            for _ in range(0, num):
                temp_tempval = max(absCORcoef)
                tempval = np.append(tempval, temp_tempval)
                for max_num2 in range(0, 104):
                    if temp_tempval == absCORcoef[max_num2]:
                        absCORcoef[max_num2] = 0
                        temp_train_well = np.append(temp_train_well, train_well[max_num2, :])
                        temp_train_well_seisic = np.append(temp_train_well_seisic, train_well_seismic[max_num2, :])

        num = num_well
        # Min-max normalised correlations of the selected wells, rescaled to sum to one.
        # NOTE(review): these weights are not consumed by the training step below.
        maxval = max(tempval)
        minval = min(tempval)
        max_minlen = maxval - minval
        tempval = (tempval - minval) / max_minlen
        valsum = sum(tempval)
        tempval = tempval / valsum

        tempval = torch.from_numpy(tempval)
        tempval = tempval.view(1, -1)
        tempval = tempval.float()
        tempval = tempval.to(device)

        # One row per selected well.
        temp_train_well = torch.from_numpy(temp_train_well)
        temp_train_well = temp_train_well.view(num, -1)
        temp_train_well = temp_train_well.float()

        temp_lable = torch.from_numpy(Xline1_110_label_impedance[trace_number, :])
        temp_lable = temp_lable.float()
        temp_lable = temp_lable.view(1, -1)

        # Accumulate over every window of the epoch. It was previously reset per window, so
        # only the last window was logged, and it was unbound when number == 0.
        epoch_loss = []
        for _ in range(0, number):
            rand = np.random.randint(0, 60 - BATCH_LEN + 1, 1)
            window = slice(rand[0], rand[0] + BATCH_LEN)
            temp_train_seismic = torch.from_numpy(train1_110_seismic[trace_number, window])
            temp_train_seismic = temp_train_seismic.float()
            temp_train_seismic = temp_train_seismic.to(device)
            temp_train_seismic = temp_train_seismic.view(1, -1)

            # Network input: the selected well logs plus the target seismic window.
            train_dataset = MyDataset2(temp_train_well[:, window], temp_lable[:, window])
            train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=1, shuffle=True, drop_last=False)

            for train_dt, train_lable in train_dataloader:
                train_dt, train_lable = train_dt.to(device), train_lable.to(device)
                train_dt = train_dt.float()
                train_lable = train_lable.float()

                model.train()
                optimizer.zero_grad()
                output = model(train_dt, temp_train_seismic)
                if is_synseismic == 1:
                    syn_seismic = syn_seismic_fun2(output, wavelet)
                    syn_seismic = syn_seismic.float()
                    loss = F.mse_loss(syn_seismic, temp_train_seismic) + F.mse_loss(output, train_lable)
                else:
                    loss = F.mse_loss(output, train_lable)

                loss.backward()
                optimizer.step()

                epoch_loss.append(loss.item())
                last_batch = (train_dt, temp_train_seismic)

        # Stop early if the monitored LSTM weights have been frozen for too long.
        temp_b = torch.sum(model.lstm60.weight_hh_l0.data) + torch.sum(model.lstm60.weight_ih_l0.data)
        if temp_a == temp_b:
            count = count + 1
        else:
            count = 0
        if count > 50:
            break

        epoch_loss = np.sum(np.array(epoch_loss))
        writer.add_scalar('Train/MSE', epoch_loss, epoch)
        epoch_num = epoch
        print('Train set: Average loss: {:.15f}'.format(epoch_loss))
        # NOTE(review): `temp` is never lowered in this function, so the checkpoint is
        # overwritten every epoch. The other training routines in this module set
        # `temp = epoch_loss` after saving; confirm which behaviour is intended.
        if epoch_loss < temp:
            path = './model_file/pre_trained_network_model_model1/pre_trained_network_model_SMI_2020_07_31_0%d.pth' % date_num
            torch.save(model.state_dict(), path)
        path_loss = './Temporary_parameters/pre_temp_model1.mat'
        path_epoch = './Temporary_parameters/pre_epoch_num_model1.mat'
        scipio.savemat(path_loss, {'epoch_loss': epoch_loss})
        scipio.savemat(path_epoch, {'epoch_num': epoch_num})
    if is_consistent == 0:
        pathmat = './SMI_out/map_number_2020_07_31_0%d.mat' % date_num
        scipio.savemat(pathmat, {'map_xline': map_xline, 'map_inline': map_inline})
    # Only log the graph if at least one optimisation step produced example inputs.
    if last_batch is not None:
        writer.add_graph(model, last_batch)
    writer.close()
def inversion(judge):
    """Invert each trace of ``train_seismic`` for impedance with its own ``ConvNet3``.

    For every trace ``k`` a network is trained for epochs ``epoch_num..EPOCHS``. The network is
    fresh when ``judge == 0``; otherwise it is restored from the trace's saved checkpoint. The
    loss is MSE(synthetic seismic, observed trace) + MSE(output, network input), where the
    synthetic comes from ``syn_seismic_fun(output, wavelet)``. The dataset pairs ``m0[k]`` with
    ``train_seismic[k]``; presumably MyDataset2 yields them as (input, label), so confirm this.
    Whenever the epoch loss improves, the checkpoint is saved, the current output is stored in
    ``impedance[k]``, and the whole ``impedance`` array is written to
    ``./SMI_out/inversion_Impedance2020_04_12_02.mat``.

    Args:
        judge: 0 to start every trace from scratch; otherwise resume from the per-trace
            checkpoint and the best-loss / epoch values stored under ``./Temporary_parameters``.
    """
    trace_number = m0.shape[0]
    point_number = m0.shape[1]
    impedance = np.zeros((m0.shape[0], m0.shape[1]))
    # Flatten the wavelet to 1-D; view(size(1)) requires `wavele` to have a single row.
    wavelet = torch.from_numpy(wavele)
    wavelet = wavelet.float()
    wavelet = wavelet.view(wavelet.size(1))
    wavelet = wavelet.to(device)
    path_temp = './Temporary_parameters/inversion_temp.mat'
    path_epoch = './Temporary_parameters/inversion_epoch_num.mat'
    for k in range(0, trace_number):
        writer = SummaryWriter(
            log_dir=
            './loss/inversion_loss_SMI2020_04_12_02/inversion_loss_SMI_%d' % k)
        # The writer is closed exactly once per trace, after training (or on error).
        # It was previously closed inside the "loss improved" branch and then written to again.
        try:
            train_dataset = MyDataset2(
                m0[k, :].reshape(1, point_number),
                train_seismic[k, :].reshape(1, point_number))
            train_dataloader = DataLoader(train_dataset,
                                          batch_size=BATCH_SIZE,
                                          num_workers=1,
                                          shuffle=True,
                                          drop_last=False)

            model = ConvNet3().to(device)
            if judge == 0:
                temp = 10000000000000  # best epoch loss so far
                epoch_num = 1
            else:
                mode_patch = './model_file/inversion_model_SMI2020_04_12_02/inversion_model_SMI_%d.pth' % k
                # map_location lets a checkpoint written on a GPU be restored on a CPU-only host.
                model.load_state_dict(torch.load(mode_patch, map_location=device))
                temp = scipio.loadmat(path_temp)['temp'].item()
                epoch_num = scipio.loadmat(path_epoch)['epoch_num'].item() + 1

            optimizer = optim.Adam(model.parameters(), lr=lr)

            graph_logged = False
            for epoch in range(epoch_num, EPOCHS + 1):

                print(k, epoch)
                epoch_loss = []

                for train_dt, train_lable in train_dataloader:
                    train_dt, train_lable = train_dt.to(device), train_lable.to(
                        device)
                    train_dt = train_dt.float()
                    train_lable = train_lable.float()

                    model.train()
                    optimizer.zero_grad()
                    output = model(train_dt)
                    syn_seismic = syn_seismic_fun(output, wavelet)
                    syn_seismic = syn_seismic.float()
                    loss = F.mse_loss(syn_seismic, train_lable) + F.mse_loss(
                        output, train_dt)
                    loss.backward()
                    optimizer.step()

                    epoch_loss.append(loss.item())
                # The graph does not change between epochs; trace it only once per writer.
                if not graph_logged:
                    writer.add_graph(model, (train_dt, ))
                    graph_logged = True

                epoch_loss = np.sum(np.array(epoch_loss))
                writer.add_scalar('Train/MSE', epoch_loss, epoch)

                epoch_num = epoch
                print('Train set: Average loss: {:.15f}'.format(epoch_loss))
                if epoch_loss < temp:
                    path = './model_file/inversion_model_SMI2020_04_12_02/inversion_model_SMI_%d.pth' % k
                    torch.save(model.state_dict(), path)
                    temp = epoch_loss
                    np_output = output.cpu().detach().numpy()
                    impedance[k, :] = np_output
                    pathmat = './SMI_out/inversion_Impedance2020_04_12_02.mat'
                    scipio.savemat(pathmat,
                                   {'inversion_Impedance2020_04_12_02': impedance})
                # Persist resume state after every epoch.
                scipio.savemat(path_temp, {'temp': temp})
                scipio.savemat(path_epoch, {'epoch_num': epoch_num})
        finally:
            writer.close()
def pre_trained(judge):
    """Pre-train ``ConvNet2_2`` to map well-side seismic (``train_well_seismic``) to ``train_well``.

    Each epoch trains on one random window of ``BATCH_LEN`` samples taken from all wells. The
    checkpoint is saved whenever the summed epoch loss beats the best seen so far. The best
    loss and the last epoch are written to ``./Temporary_parameters`` after every epoch, so
    that a later call with ``judge != 0`` can resume training.

    Args:
        judge: 0 to start from scratch; otherwise resume from the saved checkpoint, best loss
            and epoch counter.
    """
    writer = SummaryWriter(
        log_dir='./loss/pre_train_loss_model2/pre_train_loss_SMI_2020_04_17_01'
    )
    model_path = './model_file/pre_trained_network_model_model2/pre_trained_network_model_SMI_2020_04_17_01.pth'
    path_temp = './Temporary_parameters/pre_temp_model2.mat'
    path_epoch = './Temporary_parameters/pre_epoch_num_model2.mat'

    model = ConvNet2_2().to(device)
    if judge == 0:
        temp = 10000000000000  # best epoch loss so far
        epoch_num = 1
    else:
        # map_location lets a checkpoint written on a GPU be restored on a CPU-only host.
        model.load_state_dict(torch.load(model_path, map_location=device))
        temp = scipio.loadmat(path_temp)['temp'].item()
        epoch_num = scipio.loadmat(path_epoch)['epoch_num'].item() + 1

    optimizer = optim.Adam(model.parameters(), lr=lr)

    graph_logged = False
    try:
        for epoch in range(epoch_num, EPOCHS + 1):
            rand = np.random.randint(0, 60 - BATCH_LEN + 1, 1)
            window = slice(rand[0], rand[0] + BATCH_LEN)

            print(epoch)
            epoch_loss = []

            train_dataset = MyDataset2(train_well_seismic[:, window],
                                       train_well[:, window])
            train_dataloader = DataLoader(train_dataset,
                                          batch_size=BATCH_SIZE,
                                          num_workers=5,
                                          shuffle=True,
                                          drop_last=False)

            for train_dt, train_lable in train_dataloader:

                train_dt, train_lable = train_dt.to(device), train_lable.to(device)
                train_dt = train_dt.float()
                train_lable = train_lable.float()

                model.train()
                optimizer.zero_grad()
                output = model(train_dt)
                loss = F.mse_loss(output, train_lable)
                loss.backward()
                optimizer.step()

                epoch_loss.append(loss.item())
            # The graph does not change between epochs; trace it only once.
            if not graph_logged:
                writer.add_graph(model, (train_dt, ))
                graph_logged = True

            epoch_loss = np.sum(np.array(epoch_loss))
            writer.add_scalar('Train/MSE', epoch_loss, epoch)

            epoch_num = epoch
            print('Train set: Average loss: {:.15f}'.format(epoch_loss))
            if epoch_loss < temp:
                torch.save(model.state_dict(), model_path)
                temp = epoch_loss
            # Persist resume state after every epoch.
            scipio.savemat(path_temp, {'temp': temp})
            scipio.savemat(path_epoch, {'epoch_num': epoch_num})
    finally:
        # Flush and close the event file even if training raises.
        writer.close()