Example #1
def main():
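    # Trains a 2D U-Net segmentation model jointly with an autoencoder on a
    # source dataset and keeps loaders for two target datasets for evaluation.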
    ##############################################################################
    if args.server == 'server_A':
        work_dir = os.path.join('/data1/JM/lung-seg-back-up', args.exp)
        print(work_dir)
    elif args.server == 'server_B':
        work_dir = os.path.join('/data1/workspace/JM_gen/lung-seg-back-up',
                                args.exp)
        print(work_dir)
    elif args.server == 'server_D':
        work_dir = os.path.join(
            '/daintlab/home/woans0104/workspace/'
            'lung-seg-back-up', args.exp)
        print(work_dir)
    else:
        raise ValueError('Unknown server: {}'.format(args.server))
    ##############################################################################

    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    source_dataset, target_dataset1, target_dataset2 \
        = loader.dataset_condition(args.source_dataset)

    # 1.load_dataset
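    # The source loader uses the requested train split and augmentation settings;
    # the two target loaders load each full dataset (train_size=1) without
    # augmentation and serve purely as extra test sets (see test_data_li below).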
    train_loader_source, test_loader_source \
        = loader.get_loader(server=args.server,
                            dataset=source_dataset,
                            train_size=args.train_size,
                            aug_mode=args.aug_mode,
                            aug_range=args.aug_range,
                            batch_size=args.batch_size,
                            work_dir=work_dir)

    train_loader_target1, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset1,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)
    train_loader_target2, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset2,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)

    test_data_li = [
        test_loader_source, train_loader_target1, train_loader_target2
    ]

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    trn_logger_ae = Logger(os.path.join(work_dir, 'ae_train.log'))
    val_logger_ae = Logger(os.path.join(work_dir, 'ae_validation.log'))

    # 2.model_select
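    # Unet2D produces the segmentation; ae_lung is an autoencoder trained
    # alongside it. Both models and criterion_embedding are passed to train()
    # below.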
    model_seg = Unet2D(in_shape=(1, 256, 256))
    model_seg = model_seg.cuda()
    model_ae = ae_lung(in_shape=(1, 256, 256))
    model_ae = model_ae.cuda()

    cudnn.benchmark = True

    # 3.gpu select
    model_seg = nn.DataParallel(model_seg)
    model_ae = nn.DataParallel(model_ae)

    # 4.optim
    if args.optim == 'adam':
        optimizer_seg = torch.optim.Adam(model_seg.parameters(),
                                         betas=(args.adam_beta1, 0.999),
                                         eps=args.eps,
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

        optimizer_ae = torch.optim.Adam(model_ae.parameters(),
                                        betas=(args.adam_beta1, 0.999),
                                        eps=args.eps,
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)
    elif args.optim == 'adamp':
        optimizer_seg = AdamP(model_seg.parameters(),
                              betas=(args.adam_beta1, 0.999),
                              eps=args.eps,
                              lr=args.lr,
                              weight_decay=args.weight_decay)

        optimizer_ae = AdamP(model_ae.parameters(),
                             betas=(args.adam_beta1, 0.999),
                             eps=args.eps,
                             lr=args.lr,
                             weight_decay=args.weight_decay)

    elif args.optim == 'sgd':
        optimizer_seg = torch.optim.SGD(model_seg.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)

        optimizer_ae = torch.optim.SGD(model_ae.parameters(),
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)

    # lr decay
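    # Convention used throughout: every entry of args.lr_schedule except the
    # last is a decay milestone (lr is multiplied by 0.1); the last entry is
    # the total number of training epochs (see the epoch loop below).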
    lr_schedule = args.lr_schedule
    lr_scheduler_seg = optim.lr_scheduler.MultiStepLR(
        optimizer_seg, milestones=lr_schedule[:-1], gamma=0.1)

    lr_scheduler_ae = optim.lr_scheduler.MultiStepLR(
        optimizer_ae, milestones=lr_schedule[:-1], gamma=0.1)

    # 5.loss

    criterion_seg = select_loss(args.seg_loss_function)
    criterion_ae = select_loss(args.ae_loss_function)
    criterion_embedding = select_loss(args.embedding_loss_function)

    ############################################################################
    # train

    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(model_seg=model_seg,
                      model_ae=model_ae,
                      train_loader=train_loader_source,
                      epoch=epoch,
                      criterion_seg=criterion_seg,
                      criterion_ae=criterion_ae,
                      criterion_embedding=criterion_embedding,
                      optimizer_seg=optimizer_seg,
                      optimizer_ae=optimizer_ae,
                      logger=trn_logger,
                      sublogger=trn_raw_logger,
                      logger_ae=trn_logger_ae)

                iou = validate(model_seg=model_seg,
                               model_ae=model_ae,
                               val_loader=test_loader_source,
                               epoch=epoch,
                               criterion_seg=criterion_seg,
                               criterion_ae=criterion_ae,
                               logger=val_logger,
                               logger_ae=val_logger_ae)

                print('validation result ************************************')

                lr_scheduler_seg.step()
                lr_scheduler_ae.step()

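                # with no held-out validation data (val_size == 0) every epoch
                # is treated as the best, so the latest checkpoint is kept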
                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou
                best_iou = max(iou, best_iou)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model_seg.state_dict(),
                        'optimizer': optimizer_seg.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

        print("train end")
    except RuntimeError as e:
        print(
            '#jm_private',
            'training error: please notify JM (kakao talk)'
            '\n error message : {}'.format(e))

        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    draw_curve(work_dir, trn_logger_ae, val_logger_ae, labelname='ae')

    # check/load the best checkpoint (.pth) saved during training
    check_best_pth(work_dir)

    # test
    if args.test_mode:
        print('Test mode ...')
        main_test(model=model_seg, test_loader=test_data_li, args=args)
Example #2
def main():
    # segmentation training for a single site; train/test paths come from loader.make_dataset

    print(args.work_dir, args.exp)
    work_dir = os.path.join(args.work_dir, args.exp)
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # transform
    transform1 = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # 1.train_dataset
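    # Three loaders are built from one site: an augmented, shuffled training
    # loader; a batch-size-1 validation loader without augmentation; and a
    # reusable test loader over the same held-out split.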

    train_path, test_path = loader.make_dataset(args.train_site,
                                                train_size=args.train_size,
                                                mode='train')

    np.save(os.path.join(work_dir, '{}_test_path.npy'.format(args.train_site)),
            test_path)

    train_image_path = train_path[0]
    train_label_path = train_path[1]
    test_image_path = test_path[0]
    test_label_path = test_path[1]

    train_dataset = loader.CustomDataset(train_image_path,
                                         train_label_path,
                                         args.train_site,
                                         args.input_size,
                                         transform1,
                                         arg_mode=args.arg_mode,
                                         arg_thres=args.arg_thres)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=4)

    val_dataset = loader.CustomDataset(test_image_path,
                                       test_label_path,
                                       args.train_site,
                                       args.input_size,
                                       transform1,
                                       arg_mode=False)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4)

    Train_test_dataset = loader.CustomDataset(test_image_path, test_label_path,
                                              args.train_site, args.input_size,
                                              transform1)
    Train_test_loader = data.DataLoader(Train_test_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        num_workers=4)

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # 3.model_select
    my_net, model_name = model_select(
        args.arch,
        args.input_size,
    )

    # 4.gpu select
    my_net = nn.DataParallel(my_net).cuda()
    cudnn.benchmark = True

    # 5.optim

    if args.optim == 'adam':
        gen_optimizer = torch.optim.Adam(my_net.parameters(),
                                         lr=args.initial_lr,
                                         eps=args.eps)
    elif args.optim == 'sgd':
        gen_optimizer = torch.optim.SGD(my_net.parameters(),
                                        lr=args.initial_lr,
                                        momentum=0.9,
                                        weight_decay=args.weight_decay)

    # lr decay
    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(gen_optimizer,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=args.gamma)

    # 6.loss
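    # pos_weight scales the loss on positive pixels, which compensates for the
    # foreground/background imbalance typical of segmentation masks.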
    if args.loss_function == 'bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([args.bce_weight])).cuda()
    elif args.loss_function == 'mse':
        criterion = nn.MSELoss().cuda()


    #####################################################################################

    # train

    send_slack_message(args.token, '#jm_private',
                       '{} : starting_training'.format(args.exp))
    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(my_net, train_loader, gen_optimizer, epoch, criterion,
                      trn_logger, trn_raw_logger)
                iou = validate(val_loader,
                               my_net,
                               criterion,
                               epoch,
                               val_logger,
                               save_fig=False,
                               work_dir_name='jsrt_visualize_per_epoch')
                print('validation_iou ************************************')

                lr_scheduler.step()

                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou
                best_iou = max(iou, best_iou)
                checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(
                    epoch + 1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': my_net.state_dict(),
                        'optimizer': gen_optimizer.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

        print("train end")
    except RuntimeError as e:
        send_slack_message(
            args.token, '#jm_private',
            'training error: please notify JM (kakao talk)'
            '\n error message : {}'.format(e))
        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    send_slack_message(args.token, '#jm_private',
                       '{} : end_training'.format(args.exp))

    if args.test_mode:
        print('Test mode ...')
        # evaluate on the held-out split of the training site
        main_test(model=my_net, test_loader=[Train_test_loader], args=args)
Example #3
def main():
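    # 3D U-Net training: exam folders under <root>/images are shuffled and fed
    # to exam-level train/val datasets that share the same input statistics.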

    work_dir = os.path.join(args.work_dir, args.exp)
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    #train
    trn_image_root = os.path.join(args.trn_root, 'images')
    exam_ids = os.listdir(trn_image_root)
    random.shuffle(exam_ids)
    train_exam_ids = exam_ids

    #train_exam_ids = exam_ids[:int(len(exam_ids)*0.8)]
    #val_exam_ids = exam_ids[int(len(exam_ids) * 0.8):]

    # train_dataset
    trn_dataset = DatasetTrain(args.trn_root,
                               train_exam_ids,
                               options=args,
                               input_stats=[0.5, 0.5])
    trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    # save input stats for later use
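    # The training dataset's input statistics are reused by the validation
    # dataset below and saved to input_stats.npy so they can be reloaded later.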
    np.save(os.path.join(work_dir, 'input_stats.npy'), trn_dataset.input_stats)

    #val
    val_image_root = os.path.join(args.val_root, 'images')
    val_exam = os.listdir(val_image_root)
    random.shuffle(val_exam)
    val_exam_ids = val_exam

    # val_dataset
    val_dataset = DatasetVal(args.val_root,
                             val_exam_ids,
                             options=args,
                             input_stats=trn_dataset.input_stats)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    # make logger
    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # model_select
    if args.model == 'unet':
        net = UNet3D(1,
                     1,
                     f_maps=args.f_maps,
                     depth_stride=args.depth_stride,
                     conv_layer_order=args.conv_layer_order,
                     num_groups=args.num_groups)

    else:
        raise ValueError('Not supported network.')

    # loss_select
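    # Both BCE variants operate on raw logits; 'weight_bce' is the same loss
    # with a fixed positive-class weight of 5, while 'bce' takes the weight
    # from args.bce_weight.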
    if args.loss_function == 'bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([args.bce_weight])).cuda()
    elif args.loss_function == 'dice':
        criterion = DiceLoss().cuda()
    elif args.loss_function == 'weight_bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.FloatTensor([5])).cuda()
    else:
        raise ValueError('{} loss is not supported yet.'.format(
            args.loss_function))

    # optim_select
    if args.optim == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              nesterov=False)

    elif args.optim == 'adam':
        optimizer = optim.Adam(net.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise ValueError('{} optim is not supported yet.'.format(args.optim))

    net = nn.DataParallel(net).cuda()
    cudnn.benchmark = True

    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=0.1)

    best_iou = 0
    for epoch in range(lr_schedule[-1]):

        train(trn_loader, net, criterion, optimizer, epoch, trn_logger,
              trn_raw_logger)
        iou = validate(val_loader, net, criterion, epoch, val_logger)

        lr_scheduler.step()

        # save model parameter
        is_best = iou > best_iou
        best_iou = max(iou, best_iou)
        checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(epoch + 1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }, is_best, work_dir, checkpoint_filename)

    # visualize curve
    draw_curve(work_dir, trn_logger, val_logger)

    if args.inplace_test:
        # calc overall performance and save figures
        print('Test mode ...')
        main_test(model=net, args=args)
Example #4
def main():
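    # Joint segmentation + autoencoder training as in Example #1, but with a
    # second, independently shuffled copy of the training loader that is
    # passed to train() alongside the regular one.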
    if args.server == 'server_A':
        work_dir = os.path.join('/data1/JM/lung_segmentation', args.exp)
        print(work_dir)
    elif args.server == 'server_B':
        work_dir = os.path.join('/data1/workspace/JM_gen/lung_seg', args.exp)
        print(work_dir)

    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # transform
    transform1 = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # 1.train_dataset
    if args.val_size == 0:
        train_path, test_path = loader.make_dataset(args.server,
                                                    args.train_dataset +
                                                    '_dataset',
                                                    train_size=args.train_size)

        np.save(
            os.path.join(work_dir,
                         '{}_test_path.npy'.format(args.train_dataset)),
            test_path)

        train_image_path = train_path[0]
        train_label_path = train_path[1]
        test_image_path = test_path[0]
        test_label_path = test_path[1]

        train_dataset = loader.CustomDataset(train_image_path,
                                             train_label_path,
                                             transform1,
                                             arg_mode=args.arg_mode,
                                             arg_thres=args.arg_thres,
                                             arg_range=args.arg_range,
                                             dataset=args.train_dataset)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=4)

        # A second, independently shuffled loader over the same training data.
        train_dataset_random = loader.CustomDataset(train_image_path,
                                                    train_label_path,
                                                    transform1,
                                                    arg_mode=args.arg_mode,
                                                    arg_thres=args.arg_thres,
                                                    arg_range=args.arg_range,
                                                    dataset=args.train_dataset)
        train_loader_random = data.DataLoader(train_dataset_random,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)

        val_dataset = loader.CustomDataset(test_image_path,
                                           test_label_path,
                                           transform1,
                                           arg_mode=False,
                                           dataset=args.train_dataset)
        val_loader = data.DataLoader(val_dataset,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=4)

        # 'JSRT' test_dataset
        Train_test_dataset = loader.CustomDataset(test_image_path,
                                                  test_label_path,
                                                  transform1,
                                                  dataset=args.train_dataset)
        Train_test_loader = data.DataLoader(Train_test_dataset,
                                            batch_size=1,
                                            shuffle=True,
                                            num_workers=4)

        # 2.test_dataset_path

        # 'MC'test_dataset
        test_data1_path, _ = loader.make_dataset(args.server,
                                                 args.test_dataset1 +
                                                 '_dataset',
                                                 train_size=1)
        test_data1_dataset = loader.CustomDataset(test_data1_path[0],
                                                  test_data1_path[1],
                                                  transform1,
                                                  dataset=args.test_dataset1)
        test_data1_loader = data.DataLoader(test_data1_dataset,
                                            batch_size=1,
                                            shuffle=True,
                                            num_workers=4)

        # 'sh'test_dataset
        test_data2_path, _ = loader.make_dataset(args.server,
                                                 args.test_dataset2 +
                                                 '_dataset',
                                                 train_size=1)
        test_data2_dataset = loader.CustomDataset(test_data2_path[0],
                                                  test_data2_path[1],
                                                  transform1,
                                                  dataset=args.test_dataset2)
        test_data2_loader = data.DataLoader(test_data2_dataset,
                                            batch_size=1,
                                            shuffle=True,
                                            num_workers=0)

        test_data_list = [
            Train_test_loader, test_data1_loader, test_data2_loader
        ]

        # np.save(os.path.join(work_dir, 'input_stats.npy'), train_dataset.input_stats)

        trn_logger = Logger(os.path.join(work_dir, 'train.log'))
        trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
        val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # 3.model_select
    model_seg, model_name = model_select(args.arch_seg)
    model_ae, _ = model_select(args.arch_ae)

    # 4.gpu select
    model_seg = nn.DataParallel(model_seg).cuda()
    model_ae = nn.DataParallel(model_ae).cuda()

    #model_seg = model_seg.cuda()
    #model_ae = model_ae.cuda()

    cudnn.benchmark = True

    # 5.optim
    if args.optim == 'adam':
        optimizer_seg = torch.optim.Adam(model_seg.parameters(),
                                         lr=args.initial_lr)
        optimizer_ae = torch.optim.Adam(model_ae.parameters(),
                                        lr=args.initial_lr)

    elif args.optim == 'sgd':
        optimizer_seg = torch.optim.SGD(model_seg.parameters(),
                                        lr=args.initial_lr,
                                        weight_decay=args.weight_decay)

        optimizer_ae = torch.optim.SGD(model_ae.parameters(),
                                       lr=args.initial_lr,
                                       weight_decay=args.weight_decay)

    # if args.clip_grad :
    #
    #     import torch.nn.utils as torch_utils
    #     max_grad_norm = 1.
    #
    #     torch_utils.clip_grad_norm_(model_seg.parameters(),
    #                                 max_grad_norm
    #                                 )
    #     torch_utils.clip_grad_norm_(model_ae.parameters(),
    #                                 max_grad_norm
    #                                 )

    # lr decay
    lr_schedule = args.lr_schedule
    lr_scheduler_seg = optim.lr_scheduler.MultiStepLR(
        optimizer_seg, milestones=lr_schedule[:-1], gamma=args.gamma)

    lr_scheduler_ae = optim.lr_scheduler.MultiStepLR(
        optimizer_ae, milestones=lr_schedule[:-1], gamma=args.gamma)

    # 6.loss

    criterion_seg = loss_function_select(args.seg_loss_function)
    criterion_ae = loss_function_select(args.ae_loss_function)
    criterion_embedding = loss_function_select(args.embedding_loss_function)

    #####################################################################################

    # train

    send_slack_message('#jm_private',
                       '{} : starting_training'.format(args.exp))
    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(model_seg=model_seg,
                      model_ae=model_ae,
                      train_loader=train_loader,
                      train_loder_random=train_loader_random,
                      optimizer_seg=optimizer_seg,
                      optimizer_ae=optimizer_ae,
                      criterion_seg=criterion_seg,
                      criterion_ae=criterion_ae,
                      criterion_embedding=criterion_embedding,
                      epoch=epoch,
                      logger=trn_logger,
                      sublogger=trn_raw_logger)

                iou = validate(model=model_seg,
                               val_loader=val_loader,
                               criterion=criterion_seg,
                               epoch=epoch,
                               logger=val_logger,
                               work_dir=work_dir,
                               save_fig=False,
                               work_dir_name='{}_visualize_per_epoch'.format(
                                   args.train_dataset))
                print('validation result ************************************')

                lr_scheduler_seg.step()
                lr_scheduler_ae.step()

                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou

                best_iou = max(iou, best_iou)
                checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(
                    epoch + 1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model_seg.state_dict(),
                        'optimizer': optimizer_seg.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

        print("train end")
    except RuntimeError as e:
        send_slack_message(
            '#jm_private',
            'training error: please notify JM (kakao talk)'
            '\n error message : {}'.format(e))

        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    send_slack_message('#jm_private', '{} : end_training'.format(args.exp))
    #--------------------------------------------------------------------------------------------------------#
    # record which epoch produced the best checkpoint by creating an empty
    # marker file named after that epoch
    load_filename = os.path.join(work_dir, 'model_best.pth')
    checkpoint = torch.load(load_filename)
    ch_epoch = checkpoint['epoch']
    save_check_txt = os.path.join(work_dir, str(ch_epoch))
    with open('{}_best_checkpoint.txt'.format(save_check_txt), 'w'):
        pass

    # --------------------------------------------------------------------------------------------------------#

    # test
    if args.test_mode:
        print('Test mode ...')
        main_test(model=model_seg, test_loader=test_data_list, args=args)
Example #5
def main():
    # single segmentation network (no autoencoder) trained on a source dataset
    # and evaluated on two target datasets

    if args.server == 'server_A':
        work_dir = os.path.join('/data1/JM/lung_segmentation', args.exp)
        print(work_dir)
    elif args.server == 'server_B':
        work_dir = os.path.join('/data1/workspace/JM_gen/lung-seg-back-up',
                                args.exp)
        print(work_dir)
    elif args.server == 'server_D':
        work_dir = os.path.join(
            '/daintlab/home/woans0104/workspace/'
            'lung-seg-back-up', args.exp)
        print(work_dir)
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    source_dataset, target_dataset1, target_dataset2 \
        = loader.dataset_condition(args.source_dataset)

    # 1.load_dataset
    train_loader_source, test_loader_source \
        = loader.get_loader(server=args.server,
                            dataset=source_dataset,
                            train_size=args.train_size,
                            aug_mode=args.aug_mode,
                            aug_range=args.aug_range,
                            batch_size=args.batch_size,
                            work_dir=work_dir)

    train_loader_target1, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset1,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)

    train_loader_target2, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset2,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)

    test_data_li = [
        test_loader_source, train_loader_target1, train_loader_target2
    ]

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # 2.model_select
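    # 'unet' is the plain 2D U-Net; 'unet_norm' additionally exposes
    # normalization-related options (nomalize_con, affine, group_channel,
    # weight_std) on the command line.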
    #model_seg = select_model(args.arch)

    if args.arch == 'unet':
        model_seg = Unet2D(in_shape=(1, 256, 256))
    elif args.arch == 'unet_norm':
        model_seg = Unet2D_norm(in_shape=(1, 256, 256),
                                nomalize_con=args.nomalize_con,
                                affine=args.affine,
                                group_channel=args.group_channel,
                                weight_std=args.weight_std)

    else:
        raise ValueError('Not supported network.')

    model_seg = model_seg.cuda()

    # 3.gpu select
    model_seg = nn.DataParallel(model_seg)
    cudnn.benchmark = True

    # 4.optim

    if args.optim == 'adam':
        optimizer_seg = torch.optim.Adam(model_seg.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay,
                                         eps=args.eps)

    elif args.optim == 'adamp':
        optimizer_seg = AdamP(model_seg.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              eps=args.eps)
    elif args.optim == 'sgd':
        optimizer_seg = torch.optim.SGD(model_seg.parameters(),
                                        lr=args.lr,
                                        momentum=0.9,
                                        weight_decay=args.weight_decay)

    # lr decay
    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_seg,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=0.1)

    # 5.loss
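    # 'Cldice' combines BCE-with-logits and Dice into a single criterion
    # weighted by alpha and beta (both 1 here); the other options use a single
    # loss on its own.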

    if args.loss_function == 'bce':
        criterion = nn.BCELoss()
    elif args.loss_function == 'bce_logit':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss_function == 'dice':
        criterion = DiceLoss()
    elif args.loss_function == 'Cldice':
        bce = nn.BCEWithLogitsLoss().cuda()
        dice = DiceLoss().cuda()
        criterion = ClDice(bce, dice, alpha=1, beta=1)
    else:
        raise ValueError('{} loss is not supported yet.'.format(
            args.loss_function))

    criterion = criterion.cuda()

    ###############################################################################

    # train

    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(model=model_seg,
                      train_loader=train_loader_source,
                      epoch=epoch,
                      criterion=criterion,
                      optimizer=optimizer_seg,
                      logger=trn_logger,
                      sublogger=trn_raw_logger)

                iou = validate(model=model_seg,
                               val_loader=test_loader_source,
                               epoch=epoch,
                               criterion=criterion,
                               logger=val_logger)
                print('validation_result ************************************')

                lr_scheduler.step()

                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou
                best_iou = max(iou, best_iou)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model_seg.state_dict(),
                        'optimizer': optimizer_seg.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

        print("train end")
    except RuntimeError as e:
        print('error message : {}'.format(e))

        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    # here is load model for last pth
    check_best_pth(work_dir)

    # test
    if args.test_mode:
        print('Test mode ...')
        main_test(model=model_seg, test_loader=test_data_li, args=args)