Example #1
    def initialize(self, dataroots, load_size=64):
        if not isinstance(dataroots, list):
            dataroots = [dataroots]
        self.roots = dataroots
        self.load_size = load_size

        # image directory
        self.dir_ref = [os.path.join(root, 'ref') for root in self.roots]
        self.ref_paths = make_dataset(self.dir_ref)
        self.ref_paths = sorted(self.ref_paths)

        self.dir_p0 = [os.path.join(root, 'p0') for root in self.roots]
        self.p0_paths = make_dataset(self.dir_p0)
        self.p0_paths = sorted(self.p0_paths)

        self.dir_p1 = [os.path.join(root, 'p1') for root in self.roots]
        self.p1_paths = make_dataset(self.dir_p1)
        self.p1_paths = sorted(self.p1_paths)

        # transforms.Scale is deprecated in newer torchvision (renamed to Resize)
        transform_list = [
            transforms.Scale(load_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        self.transform = transforms.Compose(transform_list)

        # judgement directory
        self.dir_J = [os.path.join(root, 'judge') for root in self.roots]
        self.judge_paths = make_dataset(self.dir_J, mode='np')
        self.judge_paths = sorted(self.judge_paths)
Example #2
    def initialize(self, dataroot, load_size=64):
        self.root = dataroot
        self.load_size = load_size

        self.dir_p0 = os.path.join(self.root, 'p0')
        self.p0_paths = make_dataset(self.dir_p0)
        self.p0_paths = sorted(self.p0_paths)

        self.dir_p1 = os.path.join(self.root, 'p1')
        self.p1_paths = make_dataset(self.dir_p1)
        self.p1_paths = sorted(self.p1_paths)

        # transforms.Scale is deprecated in newer torchvision (renamed to Resize)
        transform_list = [
            transforms.Scale(load_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        self.transform = transforms.Compose(transform_list)

        # 'same' judgement directory
        self.dir_S = os.path.join(self.root, 'same')
        self.same_paths = make_dataset(self.dir_S, mode='np')
        self.same_paths = sorted(self.same_paths)
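
Both examples build the same preprocessing pipeline: scale the image, convert it to a tensor in [0, 1], then normalize each channel to [-1, 1]. Below is a minimal standalone sketch of what that pipeline produces, using transforms.Resize (the current name for the deprecated transforms.Scale) and a dummy PIL image:

from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(64),  # equivalent to the deprecated transforms.Scale(64)
    transforms.ToTensor(),  # HWC uint8 image -> CHW float tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
])

img = Image.new('RGB', (128, 128), color=(200, 30, 30))  # dummy image
x = transform(img)
print(x.shape, x.min().item(), x.max().item())  # torch.Size([3, 64, 64]), values in [-1, 1]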
Example #3
def main():

    # mode argument
    args = argparse.ArgumentParser()
    args.add_argument("--num_classes", type=int, default=22)
    args.add_argument("--lr", type=int, default=0.005)
    args.add_argument("--cuda", type=bool, default=True)
    args.add_argument("--num_epochs", type=int, default=30)
    args.add_argument("--model_name", type=str, default="mask_1.pth")
    args.add_argument("--prediction_file", type=str, default="prediction")
    args.add_argument("--batch", type=int, default=16)
    args.add_argument("--mode", type=str, default="train")
    
    config = args.parse_args()

    num_classes = config.num_classes
    base_lr = config.lr
    cuda = config.cuda
    num_epochs = config.num_epochs
    model_name = config.model_name
    prediction_file = config.prediction_file
    batch = config.batch
    mode = config.mode

    # get the model using a helper function
    new_model = model.get_model_instance_segmentation(num_classes)

    # train on the GPU, or on the CPU if a GPU is not available
    device = torch.device('cuda') if cuda else torch.device('cpu')

    # move the model to the GPU or CPU
    new_model.to(device)

    if mode == 'train':
        # use the dataset and the defined transforms
        '''dataset = CustomDataset(DATASET_PATH, dataloader.get_transform(train=True))
        dataset_val = CustomDataset(DATASET_PATH, dataloader.get_transform(train=False))
        # split the dataset into training and validation sets (here the last 10 samples are held out)
        indices = torch.randperm(len(dataset)).tolist()
        dataset = torch.utils.data.Subset(dataset, indices[:-10])
        dataset_val = torch.utils.data.Subset(dataset_val, indices[-10:])
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, batch_size=batch, shuffle=True, num_workers=4, collate_fn=dataloader.collate_fn)
        '''
        dataset = dataloader.make_dataset(DATASET_PATH)
        # define training and validation data loaders
        dataset_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch, shuffle=True, num_workers=4, collate_fn=dataloader.collate_fn)

        # construct an optimizer
        params = [p for p in new_model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params, lr=base_lr,
                                    momentum=0.9, weight_decay=0.0005)
            
        # create a learning rate scheduler
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        
        train(new_model, dataset_loader, device, num_epochs, optimizer=optimizer, lr_scheduler=lr_scheduler)
    
    elif mode == 'test':
        dataset_test = dataloader.make_dataset(DATASET_PATH)

        data_loader_test = torch.utils.data.DataLoader(
            dataset_test, batch_size=1, shuffle=False, collate_fn=dataloader.collate_fn)
        
        load_model(model_name, new_model)
        test(new_model, data_loader_test, device, prediction_file)

    print("That's it!")
Example #4
def main():
    # save input stats for later use

    print(args.work_dir, args.exp)
    work_dir = os.path.join(args.work_dir, args.exp)
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # transform
    transform1 = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # 1.train_dataset

    train_path, test_path = loader.make_dataset(args.train_site,
                                                train_size=args.train_size,
                                                mode='train')

    np.save(os.path.join(work_dir, '{}_test_path.npy'.format(args.train_site)),
            test_path)

    train_image_path = train_path[0]
    train_label_path = train_path[1]
    test_image_path = test_path[0]
    test_label_path = test_path[1]

    train_dataset = loader.CustomDataset(train_image_path,
                                         train_label_path,
                                         args.train_site,
                                         args.input_size,
                                         transform1,
                                         arg_mode=args.arg_mode,
                                         arg_thres=args.arg_thres)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=4)

    val_dataset = loader.CustomDataset(test_image_path,
                                       test_label_path,
                                       args.train_site,
                                       args.input_size,
                                       transform1,
                                       arg_mode=False)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4)

    Train_test_dataset = loader.CustomDataset(test_image_path, test_label_path,
                                              args.train_site, args.input_size,
                                              transform1)
    Train_test_loader = data.DataLoader(Train_test_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        num_workers=4)

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # 3.model_select
    my_net, model_name = model_select(
        args.arch,
        args.input_size,
    )

    # 4.gpu select
    my_net = nn.DataParallel(my_net).cuda()
    cudnn.benchmark = True

    # 5.optim

    if args.optim == 'adam':
        gen_optimizer = torch.optim.Adam(my_net.parameters(),
                                         lr=args.initial_lr,
                                         eps=args.eps)
    elif args.optim == 'sgd':
        gen_optimizer = torch.optim.SGD(my_net.parameters(),
                                        lr=args.initial_lr,
                                        momentum=0.9,
                                        weight_decay=args.weight_decay)

    # lr decay
    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(gen_optimizer,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=args.gamma)

    # 6.loss
    if args.loss_function == 'bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([args.bce_weight])).cuda()
    elif args.loss_function == 'mse':
        criterion = nn.MSELoss().cuda()


    #####################################################################################

    # train

    send_slack_message(args.token, '#jm_private',
                       '{} : starting_training'.format(args.exp))
    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(my_net, train_loader, gen_optimizer, epoch, criterion,
                      trn_logger, trn_raw_logger)
                iou = validate(val_loader,
                               my_net,
                               criterion,
                               epoch,
                               val_logger,
                               save_fig=False,
                               work_dir_name='jsrt_visualize_per_epoch')
                print(
                    'validation_iou **************************************************************'
                )

                lr_scheduler.step()

                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou
                best_iou = max(iou, best_iou)
                checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(
                    epoch + 1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': my_net.state_dict(),
                        'optimizer': gen_optimizer.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

        print("train end")
    except RuntimeError as e:
        send_slack_message(
            args.token, '#jm_private',
            '-----------------------------------  error train : send to message JM  & Please send a kakao talk ----------------------------------------- \n error message : {}'
            .format(e))
        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    send_slack_message(args.token, '#jm_private',
                       '{} : end_training'.format(args.exp))

    if args.test_mode:
        print('Test mode ...')
        main_test(model=my_net, test_loader=test_data_list, args=args)
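
A note on the schedule handling above: args.lr_schedule does double duty, with every entry except the last used as a MultiStepLR milestone and the last entry used as the total epoch count (for epoch in range(lr_schedule[-1])). A toy sketch with an assumed schedule of [20, 30, 40] and a throwaway model:

from torch import nn, optim

lr_schedule = [20, 30, 40]              # assumed example value for args.lr_schedule
model = nn.Linear(4, 1)                 # throwaway model
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                           milestones=lr_schedule[:-1],
                                           gamma=0.1)

for epoch in range(lr_schedule[-1]):    # 40 epochs in total
    optimizer.step()                    # the training step would go here
    scheduler.step()                    # decay lr by 0.1 at epochs 20 and 30

print(optimizer.param_groups[0]['lr'])  # ~0.001 after both milestones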
Example #5
def main():
    if args.server == 'server_A':
        work_dir = os.path.join('/data1/JM/lung_segmentation', args.exp)
        print(work_dir)
    elif args.server == 'server_B':
        work_dir = os.path.join('/data1/workspace/JM_gen/lung_seg', args.exp)
        print(work_dir)

    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # transform
    transform1 = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # 1.train_dataset
    if args.val_size == 0:
        train_path, test_path = loader.make_dataset(args.server,
                                                    args.train_dataset +
                                                    '_dataset',
                                                    train_size=args.train_size)

        np.save(
            os.path.join(work_dir,
                         '{}_test_path.npy'.format(args.train_dataset)),
            test_path)

        train_image_path = train_path[0]
        train_label_path = train_path[1]
        test_image_path = test_path[0]
        test_label_path = test_path[1]

        train_dataset = loader.CustomDataset(train_image_path,
                                             train_label_path,
                                             transform1,
                                             arg_mode=args.arg_mode,
                                             arg_thres=args.arg_thres,
                                             arg_range=args.arg_range,
                                             dataset=args.train_dataset)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=4)

        # Organize images and labels differently.
        train_dataset_random = loader.CustomDataset(train_image_path,
                                                    train_label_path,
                                                    transform1,
                                                    arg_mode=args.arg_mode,
                                                    arg_thres=args.arg_thres,
                                                    arg_range=args.arg_range,
                                                    dataset=args.train_dataset)
        train_loader_random = data.DataLoader(train_dataset_random,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)

        val_dataset = loader.CustomDataset(test_image_path,
                                           test_label_path,
                                           transform1,
                                           arg_mode=False,
                                           dataset=args.train_dataset)
        val_loader = data.DataLoader(val_dataset,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=4)

        # 'JSRT' test_dataset
        Train_test_dataset = loader.CustomDataset(test_image_path,
                                                  test_label_path,
                                                  transform1,
                                                  dataset=args.train_dataset)
        Train_test_loader = data.DataLoader(Train_test_dataset,
                                            batch_size=1,
                                            shuffle=True,
                                            num_workers=4)

        # 2.test_dataset_path

        # 'MC'test_dataset
        test_data1_path, _ = loader.make_dataset(args.server,
                                                 args.test_dataset1 +
                                                 '_dataset',
                                                 train_size=1)
        test_data1_dataset = loader.CustomDataset(test_data1_path[0],
                                                  test_data1_path[1],
                                                  transform1,
                                                  dataset=args.test_dataset1)
        test_data1_loader = data.DataLoader(test_data1_dataset,
                                            batch_size=1,
                                            shuffle=True,
                                            num_workers=4)

        # 'sh'test_dataset
        test_data2_path, _ = loader.make_dataset(args.server,
                                                 args.test_dataset2 +
                                                 '_dataset',
                                                 train_size=1)
        test_data2_dataset = loader.CustomDataset(test_data2_path[0],
                                                  test_data2_path[1],
                                                  transform1,
                                                  dataset=args.test_dataset2)
        test_data2_loader = data.DataLoader(test_data2_dataset,
                                            batch_size=1,
                                            shuffle=True,
                                            num_workers=0)

        test_data_list = [
            Train_test_loader, test_data1_loader, test_data2_loader
        ]

        # np.save(os.path.join(work_dir, 'input_stats.npy'), train_dataset.input_stats)

        trn_logger = Logger(os.path.join(work_dir, 'train.log'))
        trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
        val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # 3.model_select
    model_seg, model_name = model_select(args.arch_seg)
    model_ae, _ = model_select(args.arch_ae)

    # 4.gpu select
    model_seg = nn.DataParallel(model_seg).cuda()
    model_ae = nn.DataParallel(model_ae).cuda()

    #model_seg = model_seg.cuda()
    #model_ae = model_ae.cuda()

    cudnn.benchmark = True

    # 5.optim
    if args.optim == 'adam':
        optimizer_seg = torch.optim.Adam(model_seg.parameters(),
                                         lr=args.initial_lr)
        optimizer_ae = torch.optim.Adam(model_ae.parameters(),
                                        lr=args.initial_lr)

    elif args.optim == 'sgd':
        optimizer_seg = torch.optim.SGD(model_seg.parameters(),
                                        lr=args.initial_lr,
                                        weight_decay=args.weight_decay)

        optimizer_ae = torch.optim.SGD(model_ae.parameters(),
                                       lr=args.initial_lr,
                                       weight_decay=args.weight_decay)

    # if args.clip_grad :
    #
    #     import torch.nn.utils as torch_utils
    #     max_grad_norm = 1.
    #
    #     torch_utils.clip_grad_norm_(model_seg.parameters(),
    #                                 max_grad_norm
    #                                 )
    #     torch_utils.clip_grad_norm_(model_ae.parameters(),
    #                                 max_grad_norm
    #                                 )

    # lr decay
    lr_schedule = args.lr_schedule
    lr_scheduler_seg = optim.lr_scheduler.MultiStepLR(
        optimizer_seg, milestones=lr_schedule[:-1], gamma=args.gamma)

    lr_scheduler_ae = optim.lr_scheduler.MultiStepLR(
        optimizer_ae, milestones=lr_schedule[:-1], gamma=args.gamma)

    # 6.loss

    criterion_seg = loss_function_select(args.seg_loss_function)
    criterion_ae = loss_function_select(args.ae_loss_function)
    criterion_embedding = loss_function_select(args.embedding_loss_function)

    #####################################################################################

    # train

    send_slack_message('#jm_private',
                       '{} : starting_training'.format(args.exp))
    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(model_seg=model_seg,
                      model_ae=model_ae,
                      train_loader=train_loader,
                      train_loder_random=train_loader_random,
                      optimizer_seg=optimizer_seg,
                      optimizer_ae=optimizer_ae,
                      criterion_seg=criterion_seg,
                      criterion_ae=criterion_ae,
                      criterion_embedding=criterion_embedding,
                      epoch=epoch,
                      logger=trn_logger,
                      sublogger=trn_raw_logger)

                iou = validate(model=model_seg,
                               val_loader=val_loader,
                               criterion=criterion_seg,
                               epoch=epoch,
                               logger=val_logger,
                               work_dir=work_dir,
                               save_fig=False,
                               work_dir_name='{}_visualize_per_epoch'.format(
                                   args.train_dataset))
                print(
                    'validation result **************************************************************'
                )

                lr_scheduler_seg.step()
                lr_scheduler_ae.step()

                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou

                best_iou = max(iou, best_iou)
                checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(
                    epoch + 1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model_seg.state_dict(),
                        'optimizer': optimizer_seg.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

        print("train end")
    except RuntimeError as e:
        send_slack_message(
            '#jm_private',
            '-----------------------------------  error train : send to message JM  & Please send a kakao talk ----------------------------------------- \n error message : {}'
            .format(e))

        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    send_slack_message('#jm_private', '{} : end_training'.format(args.exp))
    #--------------------------------------------------------------------------------------------------------#
    # load the best checkpoint and record its epoch number in an empty marker file
    load_filename = os.path.join(work_dir, 'model_best.pth')
    checkpoint = torch.load(load_filename)
    ch_epoch = checkpoint['epoch']
    save_check_txt = os.path.join(work_dir, str(ch_epoch))
    f = open("{}_best_checkpoint.txt".format(save_check_txt), 'w')
    f.close()

    # --------------------------------------------------------------------------------------------------------#

    # test
    if args.test_mode:
        print('Test mode ...')
        main_test(model=model_seg, test_loader=test_data_list, args=args)
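
The save_checkpoint helper is not shown in this example; below is a minimal sketch of an implementation that would be consistent with the call sites above and with the later load of model_best.pth (hypothetical, not the author's code):

import os
import shutil
import torch

def save_checkpoint(state, is_best, work_dir, filename='checkpoint.pth'):
    # always write the latest state, and copy it to model_best.pth whenever
    # the validation IoU improved (the file loaded after training above)
    checkpoint_path = os.path.join(work_dir, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(work_dir, 'model_best.pth'))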
Example #6
def main():

    # mode argument
    args = argparse.ArgumentParser()
    args.add_argument("--num_classes", type=int, default=22)
    args.add_argument("--lr", type=int, default=0.001)
    args.add_argument("--cuda", type=bool, default=True)
    args.add_argument("--num_epochs", type=int, default=2)
    args.add_argument("--model_name", type=str, default="weights/0.pth")
    args.add_argument("--prediction_file", type=str, default="prediction")
    args.add_argument("--batch", type=int, default=16)
    args.add_argument("--mode", type=str, default="train")

    config = args.parse_args()

    num_classes = config.num_classes
    base_lr = config.lr
    cuda = config.cuda
    num_epochs = config.num_epochs
    model_name = config.model_name
    prediction_file = config.prediction_file
    batch = config.batch
    mode = config.mode

    # get the model using a helper function
    #new_model = model.get_model_instance_segmentation3(num_classes)
    #new_model = model.get_model_instance_segmentation(num_classes)
    #new_model = model.get_model_instance_segmentation2(num_classes)
    #new_model = model.get_model_instance_segmentation4(num_classes)
    #new_model = model.get_model_instance_segmentation5(num_classes)
    #new_model = model.get_model_instance_segmentation6(num_classes)
    #new_model = get_model_instance_segmentation_custom0(num_classes)
    new_model = get_model_instance_segmentation_custom1(num_classes)
    #get_model_instance_segmentation_custom1
    # train on the GPU, or on the CPU if a GPU is not available
    device = torch.device('cuda') if cuda else torch.device('cpu')

    # move the model to the GPU or CPU
    new_model.to(device)

    if mode == 'train':
        # use the dataset and the defined transforms
        '''dataset = CustomDataset(DATASET_PATH, dataloader.get_transform(train=True))
        dataset_val = CustomDataset(DATASET_PATH, dataloader.get_transform(train=False))
        # split the dataset into training and validation sets (here the last 10 samples are held out)
        indices = torch.randperm(len(dataset)).tolist()
        dataset = torch.utils.data.Subset(dataset, indices[:-10])
        dataset_val = torch.utils.data.Subset(dataset_val, indices[-10:])
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, batch_size=batch, shuffle=True, num_workers=4, collate_fn=dataloader.collate_fn)
        '''
        dataset = dataloader.make_dataset(DATASET_PATH)
        # define training and validation data loaders
        dataset_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch,
            shuffle=True,
            num_workers=0,
            collate_fn=dataloader.collate_fn)

        # construct an optimizer
        params = [p for p in new_model.parameters() if p.requires_grad]
        #optimizer = torch.optim.SGD(params, lr=base_lr,
        #                            momentum=0.9, weight_decay=0.0005)

        # create a learning rate scheduler
        #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.1)
        optimizer = torch.optim.Adam(params, lr=base_lr)
        #scheduler = StepLR(optimizer, step_size=40, gamma=0.1)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=num_epochs, eta_min=0.)
        # check parameter of model
        # 41401661 - fasterrcnn_resnet50_fpn
        print("------------------------------------------------------------")
        total_params = sum(p.numel() for p in new_model.parameters())
        print("num of parameter : ", total_params)
        trainable_params = sum(p.numel() for p in new_model.parameters()
                               if p.requires_grad)
        print("num of trainable_ parameter :", trainable_params)
        print("------------------------------------------------------------")
        print(new_model.eval())

        print(new_model.state_dict().keys())

        train(new_model,
              dataset_loader,
              device,
              num_epochs,
              optimizer=optimizer,
              lr_scheduler=lr_scheduler)
    elif mode == 'pruning':
        load_model(model_name, new_model)
        print(new_model.eval())
        print(new_model.state_dict().keys())

        print(
            "before------------------------------------------------------------"
        )
        total_params = sum(p.numel() for p in new_model.parameters())
        print("num of parameter : ", total_params)
        trainable_params = sum(p.numel() for p in new_model.parameters()
                               if p.requires_grad)
        print("num of trainable_ parameter :", trainable_params)
        print("------------------------------------------------------------")

        for name, module in new_model.named_modules():
            # prune 20% of connections in all 2D-conv layers
            if isinstance(module, torch.nn.Conv2d):
                prune.l1_unstructured(module, name='weight', amount=0.2)
                prune.remove(module, 'weight')
                print("weight remove")
            elif isinstance(module, torch.nn.Linear):
                prune.l1_unstructured(module, name='weight', amount=0.4)
                prune.remove(module, 'weight')

        print(dict(new_model.named_buffers()).keys()
              )  # to verify that all masks exist

        print(
            "after------------------------------------------------------------"
        )
        total_params = sum(p.numel() for p in new_model.parameters())
        print("num of parameter : ", total_params)
        trainable_params = sum(p.numel() for p in new_model.parameters()
                               if p.requires_grad)
        print("num of trainable_ parameter :", trainable_params)
        print("------------------------------------------------------------")

        print(
            "after------------------------------------------------------------"
        )
        print(new_model.state_dict().keys())

        save_pruning_model('./weights/{}'.format("pruning"), new_model)

    elif mode == 'test':
        gogo = os.path.join(
            '/tf/notebooks/08_road_condition_변유철/dataset/08_road_condition/test/'
        )
        dataset_test = dataloader.make_testset(gogo)

        data_loader_test = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=1,
            shuffle=False,
            collate_fn=dataloader.collate_fn)

        load_model(model_name, new_model)

        print("------------------------------------------------------------")
        total_params = sum(p.numel() for p in new_model.parameters())
        print("num of parameter : ", total_params)
        trainable_params = sum(p.numel() for p in new_model.parameters()
                               if p.requires_grad)
        print("num of trainable_ parameter :", trainable_params)
        print("------------------------------------------------------------")
        print(new_model.eval())

        test(new_model, data_loader_test, device, prediction_file)

    print("That's it gogo!")