Пример #1
0
def test():
    """Run the trained denoising model over the red-stamp test set and display results.

    Relies on the module-level ``args`` namespace for device, model choice
    and checkpoint path.
    """
    device = torch.device(args.devices if torch.cuda.is_available() else "cpu")
    # 红章图片 = "red stamp images"; resized to 256x256 for the network.
    test_dataset = HongZhang_TestDataset("/data_1/data/红章图片/test/hongzhang", (256, 256))
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # choose the model (unknown names fall back to UNet)
    if args.model == "unet":
        model = UNet(in_channels=args.image_channels, out_channels=args.image_channels)
    elif args.model == "srresnet":
        model = SRResnet(args.image_channels, args.image_channels)
    elif args.model == "eesp":
        model = EESPNet_Seg(args.image_channels, 2)
    else:
        model = UNet(in_channels=args.image_channels, out_channels=args.image_channels)
    print('loading model')
    if args.resume_model:
        resume_model(model, args.resume_model)
    # BUG FIX: eval()/to(device) used to run only inside the resume branch;
    # without a checkpoint the model stayed on CPU while the inputs were
    # moved to `device`, crashing on a device mismatch.
    model.eval()
    model.to(device)

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for batch_idx, image in enumerate(test_loader):
            image = image.to(device)
            denoised_img = model(image).cpu()
            CV2_showTensors(image.cpu(), denoised_img, timeout=5000)
def main():
    """Parse CLI arguments, load the YAML config and launch UNet training."""
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="Path to (.yml) config file.")
    parser.add_argument(
        "--load-checkpoint",
        type=str,
        default="",
        help="Path to load saved checkpoint from.",
    )
    configargs = parser.parse_args()

    # Read config file.
    # NOTE(review): yaml.load with FullLoader is safe for trusted config
    # files only; use yaml.safe_load if configs can come from users.
    with open(configargs.config, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
    print(cfg_dict)

    # set up network in/out channels details
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = UNet(n_channels=1, n_classes=1, bilinear=True)
    logging.info(
        f'Network:\n'
        f'\t{net.n_channels} input channels\n'
        f'\t{net.n_classes} output channels (classes)\n'
        f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
    net.to(device=device)

    # Optional custom loss: Dice loss instead of the trainer's default.
    custome_loss_fn = None
    if cfg_dict.get('loss_function', None) == 'dice_loss':
        logging.info(
            f"\n Use custom loss function-{cfg_dict.get('loss_function', None)}"
        )
        custome_loss_fn = DiceLoss

    # start training
    TrainNet(Net=net,
             device=device,
             root_imgs_dir=cfg_dict.get('base_dir', None),
             imgs_dir_name=cfg_dict.get("image_dir_suffix", None),
             mask_dir_name=cfg_dict.get("mask_dir_suffix", None),
             dir_checkpoint=cfg_dict.get("checkpoint_dir", None),
             epochs=cfg_dict.get("epochs", 5),
             batch_size=cfg_dict.get("batch_size", 1),
             lr=cfg_dict.get("learning_rate", 0.0001),
             val_percent=cfg_dict.get("validation", 0.2),
             save_checkpoints=True,
             img_scale=cfg_dict.get("scale", 1),
             custome_loss_fn=custome_loss_fn)
Пример #3
0
def main(args):
    """Train an EV-FlowNet-style UNet to regress optical flow from event voxel grids.

    Runs 10 passes over the event loader, printing loss/flow statistics and
    visualising motion-compensated events every 10 batches.

    Args:
        args: expects ``args.h5_file_path`` with the HDF5 event dataset path
            (assumed from usage — TODO confirm against caller).
    """
    # Voxelisation settings: 16 temporal bins built from random windows of
    # 60k events each.
    dataset_kwargs = {
        'transforms': {},
        'max_length': None,
        'sensor_resolution': None,
        'preload_events': False,
        'num_bins': 16,
        'voxel_method': {
            'method': 'random_k_events',
            'k': 60000,
            't': 0.5,
            'sliding_window_w': 500,
            'sliding_window_t': 0.1
        }
    }

    unet_kwargs = {
        'base_num_channels': 32,  # written as '64' in EVFlowNet tf code
        'num_encoders': 4,
        'num_residual_blocks': 2,  # transition
        'num_output_channels': 2,  # (x, y) displacement
        'skip_type': 'concat',
        'norm': None,
        'use_upsample_conv': True,
        'kernel_size': 3,
        'channel_multiplier': 2,
        'num_bins': 16
    }

    # Anomaly detection slows training but pinpoints NaN/inf gradient sources.
    torch.autograd.set_detect_anomaly(True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ev_loader = EventDataLoader(args.h5_file_path,
                                batch_size=1,
                                num_workers=6,
                                shuffle=True,
                                pin_memory=True,
                                dataset_kwargs=dataset_kwargs)

    H, W = ev_loader.H, ev_loader.W

    model = UNet(unet_kwargs)
    model = model.to(device)
    model.train()
    # Presumably pads frames so H/W fit the 4-level encoder — TODO confirm
    # against CropParameters.
    crop = CropParameters(W, H, 4)

    print("=== Let's use", torch.cuda.device_count(), "GPUs!")
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-5,
                                 betas=(0.9, 0.999))
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
    # raise
    # tmp_voxel = crop.pad(torch.randn(1, 9, H, W).to(device))
    # F, P = profile(model, inputs=(tmp_voxel, ))

    # Outer loop acts as 10 epochs over the (shuffled) loader.
    for idx in range(10):
        # for i, item in enumerate(tqdm(ev_loader)):
        for i, item in enumerate(ev_loader):

            events = item['events']
            voxel = item['voxel'].to(device)
            voxel = crop.pad(voxel)

            model.zero_grad()
            optimizer.zero_grad()

            # Scale network output; factor 10 appears empirical — TODO confirm.
            flow = model(voxel) * 10

            # Bound predicted displacement to keep the warp loss stable.
            flow = torch.clamp(flow, min=-40, max=40)
            loss = compute_loss(events, flow)
            loss.backward()

            # cvshow_voxel_grid(voxel.squeeze()[0:2].cpu().numpy())
            # raise
            optimizer.step()

            # Periodic diagnostics: loss plus min/max of each flow channel.
            if i % 10 == 0:
                print(
                    idx,
                    i,
                    '\t',
                    "{0:.2f}".format(loss.data.item()),
                    "{0:.2f}".format(torch.max(flow[0, 0]).item()),
                    "{0:.2f}".format(torch.min(flow[0, 0]).item()),
                    "{0:.2f}".format(torch.max(flow[0, 1]).item()),
                    "{0:.2f}".format(torch.min(flow[0, 1]).item()),
                )

                # Warp positive-polarity events with the predicted flow and
                # show the compensated image next to the input voxel.
                xs, ys, ts, ps = events
                print_voxel = voxel[0].sum(axis=0).cpu().numpy()
                print_flow = flow[0].clone().detach().cpu().numpy()
                print_co = warp_events_with_flow_torch(
                    (xs[0][ps[0] == 1], ys[0][ps[0] == 1], ts[0][ps[0] == 1],
                     ps[0][ps[0] == 1]),
                    flow[0].clone().detach(),
                    sensor_size=(H, W))
                print_co = crop.pad(print_co)
                print_co = print_co.cpu().numpy()

                cvshow_all(idx=idx * 10000 + i,
                           voxel=print_voxel,
                           flow=flow[0].clone().detach().cpu().numpy(),
                           frame=None,
                           compensated=print_co)
Пример #4
0
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')
                

    writer.close()


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = UNet(n_channels=3, n_classes=1, bilinear=True)
    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(net=net,
                  epochs=5,
                  batch_size=1,
                  lr=0.001,
                  device=device,
                  val_percent=10.0 / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
Пример #5
0
def predict(input_images: list = None,
            target_images: list = None,
            config_file: str = None,
            save_file_suffix=None):
    '''
    Compute output masks and their contour graphs for the given input images.

    Args:
       input_images (list[str]): input image filenames; if None they are taken
            from the command-line arguments instead
       target_images (list[str]): target mask filenames; if None they are
            taken from the command-line arguments instead
       config_file (str): path to the configuration file that specifies
            evaluation details; if None it is taken from the command-line
            arguments instead
    Returns:
        out_files (list[str]): output mask filenames
        countors_outs_files (list[str]): contour filenames of the output masks
        dc_val_records (list[float]): Dice coefficient between each target
            mask and its predicted mask
    '''
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Parse CLI arguments only when no config file path was passed in.
    configargs = None
    if config_file is None:
        parser = argparse.ArgumentParser()
        parser.add_argument("--config",
                            type=str,
                            help="Path to (.yml) config file.")
        parser.add_argument('--input_images',
                            '-i',
                            metavar='INPUT',
                            nargs='+',
                            help='filenames of input images')

        parser.add_argument('--target_images',
                            '-t',
                            metavar='INPUT',
                            nargs='+',
                            help='filenames of target mask images')
        configargs = parser.parse_args()
        config_file = configargs.config

    # Read config file.
    with open(config_file, "r") as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        print(cfg_dict)
    # set up network in/out channels details
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = UNet(n_channels=1, n_classes=1, bilinear=True)
    net.to(device=device)
    net.load_state_dict(
        torch.load(cfg_dict.get('model_weights', None), map_location=device))
    # Fall back to CLI filenames when the caller did not pass lists.
    # BUG FIX: `configargs` used to be referenced unconditionally here,
    # raising a NameError when config_file was given but the image lists
    # were not; fail with a clear error instead.
    if input_images is None:
        if configargs is None:
            raise ValueError(
                "input_images must be provided when config_file is given")
        input_images = configargs.input_images
    if target_images is None:
        if configargs is None:
            raise ValueError(
                "target_images must be provided when config_file is given")
        target_images = configargs.target_images
    logging.info("Model loaded !")
    out_files = get_output_filenames(in_files=input_images,
                                     output_dir=cfg_dict.get(
                                         'output_dir', None),
                                     suffix=save_file_suffix)
    countors_outs_files = []
    dc_val_records = []
    # start evaluating
    for i, (filename,
            target_filename) in enumerate(zip(input_images, target_images)):
        logging.info(
            f"\nPredicting image (unknown), Target image {target_filename}")

        img = Image.open(filename)
        target = Image.open(target_filename)

        mask, dc_val = predict_img(net=net,
                                   full_img=img,
                                   target_img=target,
                                   scale_factor=cfg_dict.get('scale', 1),
                                   out_threshold=cfg_dict.get(
                                       'mask_threshold', 0.5),
                                   device=device)
        if cfg_dict.get('save', True):
            result, contours = mask_to_image(mask, fn=contours_fn)
            result.save(out_files[i])
            out_contour = out_files[i].replace(".jpg", "-contour.jpg")
            contours.save(out_contour)
            countors_outs_files.append(out_contour)
            # Record DC value for evaluation
            dc_val_records.append(dc_val)
            logging.info(
                f"\nMask saved to {out_files[i]}, Countour saved to {out_contour}"
            )
    return out_files, countors_outs_files, dc_val_records
Пример #6
0
def train():
    """Train a denoising model on the red-stamp dataset.

    Model, loss, learning-rate schedule and checkpointing are all driven by
    the module-level ``args`` namespace.
    """
    # prepare the dataloader
    device = torch.device(args.devices if torch.cuda.is_available() else "cpu")
    # 红章图片 = "red stamp images"; resized to 256x256 for the network.
    dataset = HongZhang_Dataset3("/data_1/data/红章图片/6_12", (256, 256))
    dataset_length = len(dataset)
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4)

    # choose the model (unknown names fall back to UNet)
    if args.model == "unet":
        model = UNet(in_channels=args.image_channels,
                     out_channels=args.image_channels)
    elif args.model == "srresnet":
        model = SRResnet(args.image_channels, args.image_channels)
    elif args.model == "eesp":
        model = EESPNet_Seg(args.image_channels, 2)
    else:
        model = UNet(in_channels=args.image_channels,
                     out_channels=args.image_channels)
    model = model.to(device)

    # choose the loss type
    if args.loss == "l2":
        criterion = nn.MSELoss()
    elif args.loss == "l1":
        criterion = nn.L1Loss()
    elif args.loss == "ssim":
        criterion = SSIM()
    else:
        # BUG FIX: an unrecognised loss used to leave `criterion` unbound and
        # crash later with a NameError; fail fast with a clear message.
        raise ValueError(f"unsupported loss type: {args.loss}")

    # resume the model if needed
    if args.resume_model:
        resume_model(model, args.resume_model)

    optim = Adam(model.parameters(),
                 lr=args.lr,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 amsgrad=True)
    scheduler = lr_scheduler.MultiStepLR(optim, milestones=[20, 40], gamma=0.1)
    model.train()
    print(model)

    # start to train
    print("Starting Training Loop...")
    since = time.time()
    for epoch in range(args.epochs):
        print('Epoch {}/{}'.format(epoch, args.epochs - 1))
        print('-' * 10)
        running_loss = 0.0
        for batch_idx, (target, source) in enumerate(train_loader):
            source = source.to(device)
            target = target.to(device)
            denoised_source = model(source)
            # SSIM is a similarity measure, so minimise (1 - SSIM).
            if args.loss == "ssim":
                loss = 1 - criterion(denoised_source, target)
            else:
                loss = criterion(denoised_source, target)
            optim.zero_grad()
            loss.backward()
            optim.step()

            running_loss += loss.item() * source.size(0)
            if batch_idx % args.steps_show == 0:
                print('{}/{} Current loss {}'.format(batch_idx,
                                                     len(train_loader),
                                                     loss.item()))
        # BUG FIX: since PyTorch 1.1 the LR scheduler must step AFTER the
        # epoch's optimizer updates; stepping at the start of the epoch made
        # the first milestone decay the learning rate one epoch early.
        scheduler.step()
        epoch_loss = running_loss / dataset_length
        print('{} Loss: {:.4f}'.format('current ' + str(epoch), epoch_loss))
        if (epoch + 1) % args.save_per_epoch == 0:
            save_model(model, epoch + 1)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
Пример #7
0
                    if net.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_score))
                        # writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_score))
                        # writer.add_scalar('Dice/test', val_score, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            if ((epoch + 1) % 20 == 0):
                torch.save(net.state_dict(),
                           dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
                logging.info(f'Checkpoint {epoch + 1} saved !')


if __name__ == '__main__':

    net = UNet(n_channels=3, n_classes=1)
    # BUG FIX: fall back to CPU when CUDA is unavailable instead of crashing
    # on the hard-coded "cuda:0" device.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    net = net.to(device)
    # net.load_state_dict(torch.load(r"checkpoints/CP_epoch10.pth", map_location=device))
    # print(f'Model loaded')
    # wandb.watch(net)
    train_net(net, device=device)
Пример #8
0
def train(train_sources, eval_source):
    """Adversarial domain-adaptation training of a UNet segmentator.

    Three phases keyed on the epoch index: epochs < 50 train the segmentator
    only; epochs >= 50 also train a vendor discriminator on detached
    segmentator outputs; epochs >= 100 add an adversarial step that updates
    the segmentator encoder against the discriminator.

    Args:
        train_sources: vendors/domains used for supervised training.
        eval_source: held-out vendor/domain used only for evaluation.

    Side effects: saves the segmentator weights and loss/dice plots to disk;
    reads the data root from ``sys.argv[1]``.
    """
    path = sys.argv[1]
    dr = DataReader(path, train_sources)
    dr.read()
    print(len(dr.train.x))

    batch_size = 8
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    # Augmentation only on the training split.
    dataset_s_train = MultiDomainDataset(dr.train.x, dr.train.y, dr.train.vendor, device, DomainAugmentation())
    dataset_s_dev = MultiDomainDataset(dr.dev.x, dr.dev.y, dr.dev.vendor, device)
    dataset_s_test = MultiDomainDataset(dr.test.x, dr.test.y, dr.test.vendor, device)
    loader_s_train = DataLoader(dataset_s_train, batch_size, shuffle=True)

    dr_eval = DataReader(path, [eval_source])
    dr_eval.read()

    dataset_eval_dev = MultiDomainDataset(dr_eval.dev.x, dr_eval.dev.y, dr_eval.dev.vendor, device)
    dataset_eval_test = MultiDomainDataset(dr_eval.test.x, dr_eval.test.y, dr_eval.test.vendor, device)

    # Domain-adaptation set mixes source and eval-domain training data.
    dataset_da_train = MultiDomainDataset(dr.train.x+dr_eval.train.x, dr.train.y+dr_eval.train.y, dr.train.vendor+dr_eval.train.vendor, device, DomainAugmentation())
    loader_da_train = DataLoader(dataset_da_train, batch_size, shuffle=True)

    segmentator = UNet()
    discriminator = Discriminator(n_domains=len(train_sources))
    discriminator.to(device)
    segmentator.to(device)

    sigmoid = nn.Sigmoid()
    selector = Selector()

    s_criterion = nn.BCELoss()
    d_criterion = nn.CrossEntropyLoss()
    s_optimizer = optim.AdamW(segmentator.parameters(), lr=0.0001, weight_decay=0.01)
    d_optimizer = optim.AdamW(discriminator.parameters(), lr=0.001, weight_decay=0.01)
    # Adversarial step only updates the encoder, not the whole segmentator.
    a_optimizer = optim.AdamW(segmentator.encoder.parameters(), lr=0.001, weight_decay=0.01)
    # Adversarial loss weight; grows by 1/150 per adversarial step.
    lmbd = 1/150
    s_train_losses = []
    s_dev_losses = []
    d_train_losses = []
    eval_domain_losses = []
    train_dices = []
    dev_dices = []
    eval_dices = []
    # NOTE(review): with epochs=3 the epoch>=50/100 phases never run —
    # presumably a short debug setting; confirm before relying on it.
    epochs = 3
    da_loader_iter = iter(loader_da_train)
    for epoch in tqdm(range(epochs)):
        s_train_loss = 0.0
        d_train_loss = 0.0
        for index, sample in enumerate(loader_s_train):
            img = sample['image']
            target_mask = sample['target']

            # DA loader is cycled manually so it can differ in length from
            # the supervised loader.
            da_sample = next(da_loader_iter, None)
            if epoch == 100:
                # Learning-rate switch at the start of the adversarial phase.
                s_optimizer.defaults['lr'] = 0.001
                d_optimizer.defaults['lr'] = 0.0001
            if da_sample is None:
                da_loader_iter = iter(loader_da_train)
                da_sample = next(da_loader_iter, None)
            if epoch < 50 or epoch >= 100:
                # Training step of segmentator
                predicted_activations, inner_repr = segmentator(img)
                predicted_mask = sigmoid(predicted_activations)
                s_loss = s_criterion(predicted_mask, target_mask)
                s_optimizer.zero_grad()
                s_loss.backward()
                s_optimizer.step()
                s_train_loss += s_loss.cpu().detach().numpy()

            if epoch >= 50:
                # Training step of discriminator
                # Detached copies: discriminator updates must not flow back
                # into the segmentator here.
                predicted_activations, inner_repr = segmentator(da_sample['image'])
                predicted_activations = predicted_activations.clone().detach()
                inner_repr = inner_repr.clone().detach()
                predicted_vendor = discriminator(predicted_activations, inner_repr)
                d_loss = d_criterion(predicted_vendor, da_sample['vendor'])
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()
                d_train_loss += d_loss.cpu().detach().numpy()

            if epoch >= 100:
                # adversarial training step
                # Negative scaled loss: encoder learns to fool the
                # discriminator about the sample's vendor.
                predicted_mask, inner_repr = segmentator(da_sample['image'])
                predicted_vendor = discriminator(predicted_mask, inner_repr)
                a_loss = -1 * lmbd * d_criterion(predicted_vendor, da_sample['vendor'])
                a_optimizer.zero_grad()
                a_loss.backward()
                a_optimizer.step()
                lmbd += 1/150
        # Per-epoch evaluation; eval() here also switches the shared
        # segmentator module, which is restored via .train() below.
        inference_model = nn.Sequential(segmentator, selector, sigmoid)
        inference_model.to(device)
        inference_model.eval()
        d_train_losses.append(d_train_loss / len(loader_s_train))
        s_train_losses.append(s_train_loss / len(loader_s_train))
        s_dev_losses.append(calculate_loss(dataset_s_dev, inference_model, s_criterion, batch_size))
        eval_domain_losses.append(calculate_loss(dataset_eval_dev, inference_model, s_criterion, batch_size))

        train_dices.append(calculate_dice(inference_model, dataset_s_train))
        dev_dices.append(calculate_dice(inference_model, dataset_s_dev))
        eval_dices.append(calculate_dice(inference_model, dataset_eval_dev))

        segmentator.train()

    # Persist weights with a timestamped filename.
    date_time = datetime.now().strftime("%m%d%Y_%H%M%S")
    model_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "model", "weights", "segmentator"+str(date_time)+".pth")
    torch.save(segmentator.state_dict(), model_path)

    util.plot_data([(s_train_losses, 'train_losses'), (s_dev_losses, 'dev_losses'), (d_train_losses, 'discriminator_losses'),
               (eval_domain_losses, 'eval_domain_losses')],
              'losses.png')
    util.plot_dice([(train_dices, 'train_dice'), (dev_dices, 'dev_dice'), (eval_dices, 'eval_dice')],
              'dices.png')

    # Final test-set evaluation on both domains.
    inference_model = nn.Sequential(segmentator, selector, sigmoid)
    inference_model.to(device)
    inference_model.eval()

    print('Dice on annotated: ', calculate_dice(inference_model, dataset_s_test))
    print('Dice on unannotated: ', calculate_dice(inference_model, dataset_eval_test))
Пример #9
0
def train(args):
    """Train a UNet or DeepLabV3 model on MyDataset and checkpoint results.

    Args:
        args: namespace with ``model`` ('unet' | 'deeplab'), ``label_type``
            ('msk' | 'flow'), ``batchsize``, ``num_workers``, ``lr`` and
            ``epochs``.

    Raises:
        ValueError: if ``args.model`` names an unknown architecture.
    """
    result_path = 'result/%s/' % args.model
    if not os.path.exists(result_path):
        os.makedirs(result_path)
        os.makedirs('%simage' % result_path)
        os.makedirs('%scheckpoint' % result_path)

    train_set = MyDataset('train', args.label_type, 512)
    train_loader = DataLoader(
        train_set,
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.num_workers)

    # Manual GPU pinning: prefer device 6 on multi-GPU hosts, otherwise the
    # first GPU; fall back to CPU when CUDA is unavailable.
    if torch.cuda.device_count() > 1:
        device = 'cuda:6'
    else:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Single-channel mask output, or 2-channel (x, y) flow output.
    out_channels = 1 if args.label_type == 'msk' else 2
    print(out_channels)
    if args.model == 'unet':
        print('using unet as model!')
        model = UNet(out_channels=out_channels)
    elif args.model == 'deeplab':
        print('using deeplab as model!')
        model = torch.hub.load('pytorch/vision:v0.9.0',
                               'deeplabv3_resnet101', pretrained=False)
    else:
        # BUG FIX: previously only printed 'no model!' and crashed later
        # with a NameError on the first use of `model`; fail fast instead.
        raise ValueError('unknown model: %s' % args.model)

    model = model.to(device)

    # Fixed sample used to render a progress image each epoch.
    img_show = train_set.__getitem__(0)['x']
    img_show = torch.tensor(img_show).to(device).float()[None, :]

    model.train()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    loss_list = []
    loss_best = 10  # any real epoch loss below this replaces the best model
    for epo in tqdm(range(1, args.epochs + 1), ascii=True):
        epo_loss = []
        for idx, item in enumerate(train_loader):
            x = item['x'].to(device, dtype=torch.float)
            y = item['y'].to(device, dtype=torch.float)

            optimizer.zero_grad()

            # DeepLab returns a dict; keep only the first output channel.
            if args.model == 'unet':
                pred = model(x)
            elif args.model == 'deeplab':
                pred = model(x)['out'][:, 0][:, None]

            loss = criterion(pred, y)
            epo_loss.append(loss.data.item())

            loss.backward()
            optimizer.step()

        epo_loss_mean = np.array(epo_loss).mean()
        loss_list.append(epo_loss_mean)
        plot_loss(loss_list, '%simage/loss.png' % result_path)

        # Render the fixed sample's prediction for visual progress tracking.
        with torch.no_grad():
            if args.model == 'unet':
                pred = model(img_show.clone())
            elif args.model == 'deeplab':
                pred = model(img_show.clone())['out'][:, 0][:, None]
            if args.label_type == 'msk':
                x = img_show[0].cpu().detach().numpy().transpose((1, 2, 0))
                y = pred[0, 0].cpu().detach().numpy()
            elif args.label_type == 'flow':
                x = img_show[0].cpu().detach().numpy().transpose((1, 2, 0))
                y = pred[0].cpu().detach().numpy().transpose((1, 2, 0))

            plt.subplot(121)
            plt.imshow(x * 255)
            plt.subplot(122)
            plt.imshow(y[:, :, 0])
            plt.savefig('%simage/%d.png' % (result_path, epo))
            plt.clf()
        # Checkpoint every 3 epochs; also track the best-loss model.
        if epo % 3 == 0:
            torch.save(model, '%scheckpoint/%d.pt' % (result_path, epo))
            if epo_loss_mean < loss_best:
                loss_best = epo_loss_mean
                torch.save(model, '%scheckpoint/best.pt' % (result_path))
            np.save('%sloss.npy' % result_path, np.array(loss_list))