Exemplo n.º 1
0
def cross_validation(train_df, y, test_df, params, categorical_features=None):
    """Run 5-fold stratified cross-validation with LightGBM.

    Args:
        train_df: training features (pandas DataFrame).
        y: target aligned with ``train_df`` (pandas Series).
        test_df: test features; per-fold predictions are averaged.
        params: LightGBM model parameters forwarded to the trainer.
        categorical_features: optional list of categorical feature names.

    Returns:
        Tuple ``(score, y_test)``: the overall out-of-fold score and the
        fold-averaged test predictions.
    """
    # Avoid the shared-mutable-default pitfall; None stands in for [].
    if categorical_features is None:
        categorical_features = []

    fit_params = {
        'num_boost_round': 10000,
        'early_stopping_rounds': 200,
        'verbose_eval': 200,
    }

    print(fit_params)
    feat_name = [*train_df.columns]

    # Cross-validation
    kf = StratifiedKFold(n_splits=5, random_state=0, shuffle=True)
    y_oof = np.empty([len(train_df), ])
    y_test = []
    feature_importances = pd.DataFrame()
    for fold, (train_idx, valid_idx) in enumerate(kf.split(train_df, y.astype(int))):
        print('Fold {}'.format(fold + 1))

        x_train, y_train = train_df.iloc[train_idx][feat_name], y.iloc[train_idx]
        x_val, y_val = train_df.iloc[valid_idx][feat_name], y.iloc[valid_idx]
        x_test = test_df[feat_name]

        y_pred_valid, y_pred_test, valid_loss, importances, best_iter = trainers.train_lgbm(
            x_train, y_train, x_val, y_val, x_test, params, fit_params,
            categorical_features=categorical_features,
            feature_name=feat_name,
            fold_id=fold,
            loss_func=loss.calc_loss,
            calc_importances=True
        )

        y_oof[valid_idx] = y_pred_valid
        # Fix: index positionally with .iloc, consistent with the KFold
        # indices used above; label-based y[valid_idx] breaks on a
        # non-default index. The fold score was also computed but never
        # used — report it so per-fold quality is visible.
        fold_score = loss.calc_loss(y.iloc[valid_idx], y_pred_valid)
        print(f"fold {fold + 1} score: {fold_score}")
        y_test.append(y_pred_test)
        feature_importances = pd.concat([feature_importances, importances], axis=0, sort=False)

    # Overall validation score on the full out-of-fold predictions.
    score = loss.calc_loss(y, y_oof)
    print(f"valid score: {score}")

    # For submission: use the mean of the per-fold test predictions.
    y_test = np.mean(y_test, axis=0)

    return score, y_test
Exemplo n.º 2
0
def train_unet(model,
               dataloaders,
               optimizer,
               scheduler,
               temporal_model_dir,
               unet_device,
               bce_weight=0.25,
               num_epochs=10):
    """Train a U-Net and return it loaded with the best validation weights.

    Args:
        model: network to train.
        dataloaders: dict with 'train' and 'val' loaders yielding
            ``(inputs, labels, _, _)`` tuples.
        optimizer: optimizer over ``model`` parameters.
        scheduler: LR scheduler, stepped once per epoch.
        temporal_model_dir: unused here; kept for interface compatibility.
        unet_device: device the batches are moved to.
        bce_weight: BCE weight forwarded to ``calc_loss``.
        num_epochs: number of epochs to run.

    Returns:
        The model carrying the weights of the epoch with the lowest
        validation loss (``metrics['dice']`` normalized by sample count).
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                # Legacy (pre-1.1 PyTorch) ordering: scheduler stepped at the
                # start of the training phase; kept as-is.
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print('LR', param_group['lr'])
                model.train()
            else:
                model.eval()

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels, _, _ in tqdm(dataloaders[phase]):
                inputs = inputs.to(unet_device)
                labels = labels.to(unet_device)
                optimizer.zero_grad()

                # Gradients only in the training phase. (Fix: removed the
                # leftover per-batch get_device() debug prints that flooded
                # stdout on every iteration.)
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss, _ = calc_loss(outputs,
                                        labels,
                                        metrics,
                                        bce_weight=bce_weight)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                epoch_samples += inputs.size(0)
            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['dice'] / epoch_samples
            if phase == 'val' and epoch_loss < best_loss:
                print('best loss changed!')
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_model_wts)
    return model
def train(epoch, iter_start=0):
    """Run one training epoch for the hourglass network.

    Relies on module-level globals: netHg (model), optimHg (optimizer),
    train_loader, FLAGS, sumWriter, and global_step.

    Args:
        epoch: epoch number, used only for the progress-bar label.
        iter_start: starting offset for the iteration counter.
    """
    netHg.train()

    global global_step
    pbar = tqdm.tqdm(train_loader, desc='Epoch %02d' % epoch, dynamic_ncols=True)
    pbar_info = tqdm.tqdm(bar_format='{bar}{postfix}')
    for it, sample in enumerate(pbar, start=iter_start):
        global_step += 1
        # Debug batches additionally carry the image ids.
        if FLAGS.debug:
            image, masks, keypoints, heatmaps, img_ids = sample
        else:
            image, masks, keypoints, heatmaps = sample
        image = Variable(image)
        masks = Variable(masks)
        keypoints = Variable(keypoints)
        heatmaps = Variable(heatmaps)
        if FLAGS.cuda:
            # Fix: `async` is a reserved keyword since Python 3.7, so the
            # old .cuda(async=...) spelling is a SyntaxError; PyTorch
            # renamed the parameter to non_blocking.
            image = image.cuda(non_blocking=FLAGS.pinMem)
            masks = masks.cuda(non_blocking=FLAGS.pinMem)
            keypoints = keypoints.cuda(non_blocking=FLAGS.pinMem)
            heatmaps = heatmaps.cuda(non_blocking=FLAGS.pinMem)

        outputs = netHg(image)
        push_loss, pull_loss, detection_loss = calc_loss(outputs, keypoints, heatmaps, masks)

        # Weighted sum of the three loss terms (push/pull scaled down 1e-3).
        loss_hg = 0
        toprint = ''
        sum_dict = {}
        for loss, weight, name in zip([push_loss, pull_loss, detection_loss], [1e-3, 1e-3, 1],
                                      ['push_loss', 'pull_loss', 'detection_loss']):
            loss_temp = torch.mean(loss)
            sum_dict[name] = getValue(loss_temp)
            loss_temp *= weight
            loss_hg += loss_temp
            toprint += '{:.8f} '.format(getValue(loss_temp))

        optimHg.zero_grad()
        loss_hg.backward()
        optimHg.step()

        # Summary
        sumWriter.add_scalar('loss_hg', loss_hg, global_step)
        for key, value in sum_dict.items():
            # Fix: log each entry's own value; the old code wrote the stale
            # loop variable `loss_temp` for every key.
            sumWriter.add_scalar(key, value, global_step)

        pbar_info.set_postfix_str(toprint)
        pbar_info.update()

    pbar.close()
    pbar_info.close()
Exemplo n.º 4
0
def test_seg_model(args):
    """Evaluate a trained segmentation model on the validation split.

    Builds the architecture named by ``args.model_name``, loads the
    checkpoint ``args.best_model`` from ``args.model_dir``, runs the val
    dataloader without gradients, and prints the average dice coefficient.

    Raises:
        NotImplementedError: for an unknown ``args.model_name``.
    """
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(448, 448))
        # Replace the 19-class Cityscapes head with one sized for our task.
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        # Fix: `raise NotImplemented(...)` raises a TypeError because
        # NotImplemented is a singleton, not an exception class.
        raise NotImplementedError("Unknown model {}".format(args.model_name))
    model_path = os.path.join(args.model_dir, args.best_model)
    # Wrap in DataParallel before loading so "module." key prefixes match.
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    print('--------Start testing--------')
    since = time.time()
    dloader = gen_dloader(os.path.join(args.data_dir, "val"),
                          args.batch_size,
                          mode="val")

    metrics = defaultdict(float)
    ttl_samples = 0

    for batch_ind, (imgs, masks) in enumerate(dloader):
        if batch_ind != 0 and batch_ind % 100 == 0:
            print("Processing {}/{}".format(batch_ind, len(dloader)))
        inputs = Variable(imgs.cuda())
        masks = Variable(masks.cuda())

        with torch.no_grad():
            outputs = model(inputs)
            # calc_loss accumulates per-batch statistics into `metrics`.
            loss = calc_loss(outputs, masks, metrics)

        ttl_samples += inputs.size(0)
    avg_dice = metrics['dice'] / ttl_samples
    time_elapsed = time.time() - since
    print('Testing takes {:.0f}m {:.2f}s'.format(time_elapsed // 60,
                                                 time_elapsed % 60))
    print("----Dice coefficient is: {:.3f}".format(avg_dice))
Exemplo n.º 5
0
def train_seg_model(args):
    """Train a segmentation model (UNet or PSPNet) with train/val phases.

    Selects the architecture from ``args.model_name`` and the optimizer
    from ``args.optim_name``, decays the LR via LambdaLR, and saves a
    checkpoint whenever the validation dice improves (and unconditionally
    during the last 5 epochs).

    Raises:
        AssertionError: on an unknown model or optimizer name.
    """
    # model
    model = None
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(512, 512))
        model.load_pretrained_model(
            model_path="./segnet/pspnet/pspnet101_cityscapes.caffemodel")
        # Replace the Cityscapes head with one sized for our classes.
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        # Fix: corrected the typo'd message ("Unknow modle").
        raise AssertionError("Unknown model: {}".format(args.model_name))
    model = nn.DataParallel(model)
    model.cuda()
    # optimizer
    optimizer = None
    if args.optim_name == "Adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=1.0e-3)
    elif args.optim_name == "SGD":
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.init_lr,
                              momentum=0.9,
                              weight_decay=0.0005)
    else:
        # Fix: corrected the typo'd message ("Unknow optimizer").
        raise AssertionError("Unknown optimizer: {}".format(args.optim_name))
    scheduler = lr_scheduler.LambdaLR(optimizer,
                                      lr_lambda=LambdaLR(args.maxepoch, 0,
                                                         0).step)
    # dataloaders for the train and val phases
    train_data_dir = os.path.join(args.data_dir, args.tumor_type, "train")
    train_dloader = gen_dloader(train_data_dir,
                                args.batch_size,
                                mode="train",
                                normalize=args.normalize,
                                tumor_type=args.tumor_type)
    test_data_dir = os.path.join(args.data_dir, args.tumor_type, "val")
    val_dloader = gen_dloader(test_data_dir,
                              args.batch_size,
                              mode="val",
                              normalize=args.normalize,
                              tumor_type=args.tumor_type)

    # training
    save_model_dir = os.path.join(args.model_dir, args.tumor_type,
                                  args.session)
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    best_dice = 0.0
    for epoch in np.arange(0, args.maxepoch):
        print('Epoch {}/{}'.format(epoch + 1, args.maxepoch))
        print('-' * 10)
        since = time.time()
        for phase in ['train', 'val']:
            if phase == 'train':
                dloader = train_dloader
                # NOTE: stepping the scheduler at the start of the epoch is
                # the legacy (pre-1.1) PyTorch ordering; kept as-is.
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("Current LR: {:.8f}".format(param_group['lr']))
                model.train()  # Set model to training mode
            else:
                dloader = val_dloader
                model.eval()  # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            for batch_ind, (imgs, masks) in enumerate(dloader):
                inputs = Variable(imgs.cuda())
                masks = Variable(masks.cuda())
                optimizer.zero_grad()

                # Gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs,
                                     masks,
                                     metrics,
                                     bce_weight=args.bce_weight)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                epoch_samples += inputs.size(0)
            print_metrics(metrics, epoch_samples, phase)
            epoch_dice = metrics['dice'] / epoch_samples

            # Checkpoint on validation improvement, and unconditionally
            # during the last 5 epochs.
            if phase == 'val' and (epoch_dice > best_dice
                                   or epoch > args.maxepoch - 5):
                best_dice = epoch_dice
                best_model = copy.deepcopy(model.state_dict())
                best_model_name = "-".join([
                    args.model_name,
                    "{:03d}-{:.3f}.pth".format(epoch, best_dice)
                ])
                torch.save(best_model,
                           os.path.join(save_model_dir, best_model_name))
        time_elapsed = time.time() - since
        print('Epoch {:2d} takes {:.0f}m {:.0f}s'.format(
            epoch, time_elapsed // 60, time_elapsed % 60))
    print(
        "================================================================================"
    )
    print("Training finished...")
Exemplo n.º 6
0
def main(args):
    """Tower-parallel (multi-GPU) TF1 training entry point for M2Det.

    Loads batches from TFRecords, builds one model replica per GPU listed
    in args.gpu, averages gradients across towers, writes TensorBoard
    summaries, and checkpoints weights every `save_weight_period` epochs.
    """
    #******************************************** load args
    batch_size = args.batch_size
    data_record_dir = args.tfrecord_dir
    log_dir = args.log_dir
    sfam_fg = args.sfam
    box_loss_scale = args.scale
    model_dir = args.model_dir
    load_num = args.load_num
    epoches = args.epoches
    save_weight_period = args.save_weight_period
    gpu_list = [int(i) for i in args.gpu.split(',')]
    gpu_num = len(gpu_list)
    #******************************************** create logging
    # Log to both a timestamped file and the console.
    logger = logging.getLogger()
    hdlr = logging.FileHandler(
        os.path.join(
            log_dir,
            time.strftime('%F-%T', time.localtime()).replace(':', '-') +
            '.log'))
    #formatter = logging.Formatter('[%(asctime)s] [%(levelname)s] [%(threadName)-10s] %(message)s')
    #hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    logger.info("train gpu:{}".format(gpu_list))
    #******************************************** load data
    # Per-anchor target width: cfgs.ClsNum class entries plus 6 extra
    # fields (exact layout defined by process_imgs — confirm against cfgs).
    y_true_size = cfgs.ClsNum + 6
    with tf.name_scope("Load_data"):
        tfrd = Read_Tfrecord(cfgs.DataSet_Name, data_record_dir, batch_size,
                             1000)
        num_obj_batch, img_batch, gtboxes_label_batch = tfrd.next_batch()
        # Anchors and preprocessing run as py_func ops, so their static
        # shapes must be set by hand.
        anchors = tf.py_func(generate_anchors, [], tf.float32)
        anchors.set_shape([None, 4])
        x_batch, y_true = tf.py_func(process_imgs, [
            img_batch, gtboxes_label_batch, num_obj_batch, batch_size, anchors
        ], [tf.float32, tf.float32])
        x_batch.set_shape([None, cfgs.ImgSize, cfgs.ImgSize, 3])
        y_true.set_shape([None, cfgs.AnchorBoxes, y_true_size])
        # Split the batch evenly across the GPUs (one slice per tower).
        images_s = tf.split(x_batch, num_or_size_splits=gpu_num, axis=0)
        labels_s = tf.split(y_true, num_or_size_splits=gpu_num, axis=0)
    #***************************************************************build trainer
    with tf.name_scope("optimizer"):
        global_step = tf.train.get_or_create_global_step()
        # Piecewise-constant LR schedule from cfgs.DECAY_STEP / cfgs.LR.
        lr = tf.train.piecewise_constant(
            global_step,
            boundaries=[np.int64(x) for x in cfgs.DECAY_STEP],
            values=[y for y in cfgs.LR])
        #tf.summary.scalar('lr', lr)
        #opt = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)
        opt = tf.train.AdamOptimizer(learning_rate=lr)
    #*****************************************************************get multi-model net
    # Calculate the gradients for each model tower.
    tower_grads = []
    loss_scalar_dict = {}
    loss_hist_dict = {}
    # NOTE(review): these accumulators are only ever assign_add-ed and
    # never reset, so the fetched values grow over the run — confirm that
    # is intended before interpreting the logged bbox/cls averages.
    all_ave_cls_loss = tf.Variable(0.0,
                                   name='all_ave_cls_loss',
                                   trainable=False)
    all_ave_bbox_loss = tf.Variable(0.0,
                                    name='all_ave_bbox_loss',
                                    trainable=False)
    is_training = tf.constant(True)
    for i, idx in enumerate(gpu_list):
        # Share variables across towers: reuse for every tower but the first.
        with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
            with tf.device('/gpu:%d' % idx):
                with tf.name_scope('%s_%d' % (args.tower_name, idx)) as scope:
                    net = M2Det(images_s[i], is_training, sfam_fg)
                    y_pred = net.prediction
                    # reuse variables for subsequent graph construction
                    tf.get_variable_scope().reuse_variables()
                    total_loss, bbox_loss, class_loss = calc_loss(
                        labels_s[i], y_pred, box_loss_weight=box_loss_scale)
                    loss_scalar_dict['cls_box/cb_loss_%d' % idx] = total_loss
                    ave_box_loss = tf.reduce_mean(bbox_loss)
                    all_ave_bbox_loss.assign_add(ave_box_loss)
                    ave_clss_loss = tf.reduce_mean(class_loss)
                    all_ave_cls_loss.assign_add(ave_clss_loss)
                    # L2 weight decay over all kernel variables.
                    weights = [
                        v for v in tf.get_collection(
                            tf.GraphKeys.GLOBAL_VARIABLES)
                        if 'kernel' in v.name
                    ]
                    decay = tf.reduce_sum(
                        tf.stack([tf.nn.l2_loss(w) for w in weights])) * 5e-5
                    loss_scalar_dict['weight/weight_loss_%d' % idx] = decay
                    # NOTE(review): assign_add is a Variable method and its
                    # result op is never used; total_loss comes from
                    # calc_loss — verify the decay term actually reaches the
                    # loss passed to compute_gradients below.
                    total_loss.assign_add(decay)
                    loss_scalar_dict['total/total_loss_%d' % idx] = total_loss
                    loss_scalar_dict['bbox/ave_box_loss_%d' %
                                     idx] = ave_box_loss
                    loss_scalar_dict['class/ave_class_loss_%d' %
                                     idx] = ave_clss_loss
                    loss_hist_dict['bbox/box_loss_%d' % idx] = bbox_loss
                    loss_hist_dict['class/class_loss_%d' % idx] = class_loss
                    grads = opt.compute_gradients(total_loss)
                    tower_grads.append(grads)
                    #tf.add_to_collection("total_loss",total_loss)
    #************************************************************************************compute gradients
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        grads = average_gradients(tower_grads)
        # Apply the gradients to adjust the shared variables.
        train_op = opt.apply_gradients(grads, global_step=global_step)
        #train_loss = tf.reduce_mean(tf.get_collection("total_loss"),0)
        #train_op = optimizer.minimize(train_loss,colocate_gradients_with_ops=True)
    #*****************************************************************************************add summary
    summaries = []
    # add gradient histograms
    for grad, var in grads:
        if grad is not None:
            summaries.append(
                tf.summary.histogram(var.op.name + '/gradients', grad))
    # add trainable-variable histograms
    for var in tf.trainable_variables():
        summaries.append(tf.summary.histogram(var.op.name, var))
    # add loss summaries
    for keys, val in loss_scalar_dict.items():
        summaries.append(tf.summary.scalar(keys, val))
    for keys, val in loss_hist_dict.items():
        summaries.append(tf.summary.histogram(keys, val))
    # add learning rate (tag spelling kept for TensorBoard continuity)
    summaries.append(tf.summary.scalar('leraning_rate', lr))
    summary_op = tf.summary.merge(summaries)
    #***********************************************************************************training
    with tf.name_scope("training_op"):
        tf_config = tf.ConfigProto()
        #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
        #tf_config.gpu_options = gpu_options
        tf_config.gpu_options.allow_growth = True
        tf_config.log_device_placement = False
        sess = tf.Session(config=tf_config)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        saver = tf.train.Saver(max_to_keep=10)
        # load model (resume from a checkpoint when load_num is given)
        model_path = os.path.join(model_dir, cfgs.DataSet_Name)
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        model_path = os.path.join(model_path, cfgs.ModelPrefix)
        if load_num is not None:
            #assert tf.train.get_checkpoint_state(model_dir),'the params dictionary is not valid'
            model_path = "%s-%s" % (model_path, load_num)
            saver.restore(sess, model_path)
            logger.info('Resuming training %s' % model_path)
        # build summary writer
        #summary_op = tf.summary.merge_all()
        summary_path = os.path.join(log_dir, 'summary')
        if not os.path.exists(summary_path):
            os.makedirs(summary_path)
        summary_writer = tf.summary.FileWriter(summary_path, graph=sess.graph)
        # begin to train
        try:
            for epoch_tmp in range(epoches):
                for step in range(
                        np.ceil(cfgs.Train_Num / batch_size).astype(np.int32)):
                    training_time = time.strftime('%Y-%m-%d %H:%M:%S',
                                                  time.localtime(time.time()))
                    global_value = sess.run(global_step)
                    # Fast path: steps with neither console logging nor
                    # summary writing just run the train op.
                    if step % cfgs.SHOW_TRAIN_INFO != 0 and step % cfgs.SMRY_ITER != 0:
                        _ = sess.run([train_op])
                    else:
                        if step % cfgs.SHOW_TRAIN_INFO == 0:
                            # NOTE(review): total_loss fetched here is the
                            # last tower's loss tensor from graph build time.
                            _, loss_value, cur_lr, ave_box, ave_cls = sess.run(
                                [
                                    train_op, total_loss, opt._lr,
                                    all_ave_bbox_loss, all_ave_cls_loss
                                ])
                            logger.info(
                                '{} \t epoch:{}, lr:{}, step: {}, loss: {} , bbox:{}, cls:{}'
                                .format(str(training_time), epoch_tmp, cur_lr,
                                        global_value, loss_value,
                                        ave_box / gpu_num, ave_cls / gpu_num))
                        if step % cfgs.SMRY_ITER == 0:
                            _, summary_str = sess.run([train_op, summary_op])
                            summary_writer.add_summary(summary_str,
                                                       global_value)
                            summary_writer.flush()
                # Periodic checkpoint, plus one at the final epoch.
                if (epoch_tmp > 0 and epoch_tmp % save_weight_period
                        == 0) or (epoch_tmp == epoches - 1):
                    dst = model_path
                    saver.save(sess, dst, epoch_tmp, write_meta_graph=False)
                    logger.info(">>*************** save weight ***: %d" %
                                epoch_tmp)
        except tf.errors.OutOfRangeError:
            print("Trianing is error")
        finally:
            # Always stop the queue-runner threads and release the session.
            coord.request_stop()
            summary_writer.close()
            coord.join(threads)
            #record_file_out.close()
            sess.close()
def test_unet(model, dataloaders, root_dir, test_subjects, save_dir,
              unet_device):
    """Run inference over test patches, assemble per-subject 3D volumes,
    save them as NIfTI, and write averaged metrics to a report file.

    Args:
        model: trained network (switched to eval mode here).
        dataloaders: loader yielding (inputs, labels, subject, name); the
            patch position is encoded in `name` as "<prefix>_z_x_y.npy".
        root_dir: directory containing the 'npy_labels' ground-truth arrays.
        test_subjects: subject ids to evaluate.
        save_dir: output directory for predictions and the metrics report.
        unet_device: device the batches are moved to.
    """
    model.eval()
    epoch_samples = 0
    metrics = defaultdict(float)

    result_path = os.path.join(save_dir, 'testing_result.txt')

    # Pre-allocate one accumulation volume per subject at original shape.
    GT_dir = os.path.join(root_dir, 'npy_labels')
    subject_seg = set_subject_result(GT_dir, test_subjects)
    with torch.no_grad():
        for inputs, labels, subject, name in tqdm(dataloaders):
            subject = subject[0]
            name = name[0].replace('.npy', '')

            inputs = inputs.to(unet_device)
            labels = labels.to(unet_device)

            outputs = model(inputs)
            # Fix: F.sigmoid is deprecated; torch.sigmoid is the drop-in
            # replacement. Threshold at 0.5 to binarize.
            thresh_outputs = torch.sigmoid(outputs)
            thresh_outputs[thresh_outputs >= 0.5] = 1.0
            thresh_outputs[thresh_outputs < 0.5] = 0.0

            # calc_loss mutates `metrics` in place; its returns are unused.
            loss, dice = calc_loss(outputs, labels, metrics)
            pred = np.squeeze(thresh_outputs.cpu().data[0].numpy()).astype(
                np.float16)

            # Patch offset is encoded in the name as "<prefix>_z_x_y".
            p_z, p_x, p_y = pred.shape
            _, z, x, y = name.split('_')
            z, x, y = int(z), int(x), int(y)
            # Average overlapping patch predictions into the volume.
            subject_seg[subject][
                z:z + p_z, x:x + p_x, y:y +
                p_y] = (subject_seg[subject][z:z + p_z, x:x + p_x, y:y + p_y] +
                        pred) / 2

            epoch_samples += inputs.size(0)

    source_dir = '/media/NAS/nas_187/datasets/junghwan/experience/CT/TCIA/Labels'
    for key in subject_seg:
        origin_label = os.path.join(source_dir, 'label{}.nii.gz'.format(key))
        origin_3D = sitk.ReadImage(origin_label, sitk.sitkInt16)

        subject_dir = os.path.join(save_dir, key)
        if not os.path.exists(subject_dir):
            os.mkdir(subject_dir)
        save_path = os.path.join(subject_dir, 'pred.nii')

        # Binarize the averaged volume at 0.4.
        subject_seg[key][subject_seg[key] >= 0.4] = 1
        subject_seg[key][subject_seg[key] < 0.4] = 0

        result_3D = sitk.GetImageFromArray(subject_seg[key].astype(np.uint8))

        # Copy geometry (orientation/origin/spacing) from the GT volume.
        result_3D.SetDirection(origin_3D.GetDirection())
        result_3D.SetOrigin(origin_3D.GetOrigin())
        result_3D.SetSpacing(origin_3D.GetSpacing())

        sitk.WriteImage(result_3D, save_path)
        del result_3D

    # Fix: write the report inside a context manager so the handle cannot
    # leak if inference raises (it was previously opened at function start
    # and left open throughout, shadowing the old builtin name `file`).
    with open(result_path, 'w') as result_file:
        for k in metrics.keys():
            result_file.write('{}: {:.4f}\n'.format(k,
                                                    metrics[k] / epoch_samples))
    print_metrics(metrics, epoch_samples, 'test')
Exemplo n.º 8
0
def train_model():
    global global_step
    F_txt = open('./opt_results.txt', 'w')
    evaluator = Evaluator(args.n_cls)
    classes = ['road', 'others']
    writer = SummaryWriter(args.save_path + folder_path + '/log')

    def create_model(ema=False):
        model = hrnet18(pretrained=True).to(device)
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)
    # model = hrnet18(pretrained=True).to(device)
    # model = nn.DataParallel(model)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=0.0001)
    ema_optimizer = WeightEMA(model, ema_model, alpha=0.999)
    best_miou = 0.
    best_AA = 0.
    best_OA = 0.
    best_loss = 0.
    lr = args.lr
    epoch_index = 0
    if args.is_resume:
        args.resume = args.save_path + folder_path + '/checkpoint_fwiou.pth'
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            epoch_index = checkpoint['epoch']
            best_miou = checkpoint['miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            F_txt.write("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']) + '\n')
            # print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), file=F_txt)
        else:
            print('EORRO: No such file!!!!!')

    TRAIN_DATA_DIRECTORY = args.root  # '/media/ws/www/IGARSS'
    TRAIN_DATA_LIST_PATH = args.train_list_path  # '/media/ws/www/unet_1/data/train.txt'

    VAL_DATA_DIRECTORY = args.root  # '/media/ws/www/IGARSS'
    VAL_DATA_LIST_PATH = args.val_list_path  # '/media/ws/www/unet_1/data/train.txt'

    dataloaders = {
        "train":
        DataLoader(GaofenTrain(TRAIN_DATA_DIRECTORY, TRAIN_DATA_LIST_PATH),
                   batch_size=args.batchsize,
                   shuffle=True,
                   num_workers=args.num_workers,
                   pin_memory=True,
                   drop_last=True),
        "val":
        DataLoader(GaofenVal(VAL_DATA_DIRECTORY, VAL_DATA_LIST_PATH),
                   batch_size=args.batchsize,
                   num_workers=args.num_workers,
                   pin_memory=True)
    }

    evaluator.reset()
    print('config: ' + folder_path)
    print('config: ' + folder_path, file=F_txt, flush=True)
    for epoch in range(epoch_index, args.num_epochs):
        print('Epoch [{}]/[{}] lr={:6f}'.format(epoch + 1, args.num_epochs,
                                                lr))
        # F_txt.write('Epoch [{}]/[{}] lr={:6f}'.format(epoch + 1, args.num_epochs, lr)+'\n',flush=True)
        print('Epoch [{}]/[{}] lr={:4f}'.format(epoch + 1, args.num_epochs,
                                                lr),
              file=F_txt,
              flush=True)
        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            evaluator.reset()
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                ema_model.eval()
                model.eval()  # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for i, (inputs, labels, edge, _,
                    datafiles) in enumerate(tqdm(dataloaders[phase],
                                                 ncols=50)):
                inputs = inputs.to(device)
                edge = edge.to(device, dtype=torch.float)
                labels = labels.to(device, dtype=torch.long)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    if phase == 'train':
                        outputs = model(inputs)
                        outputs[1] = F.interpolate(input=outputs[1],
                                                   size=(labels.shape[1],
                                                         labels.shape[2]),
                                                   mode='bilinear',
                                                   align_corners=True)
                        loss = calc_loss(outputs, labels, edge, metrics)
                        pred = outputs[1].data.cpu().numpy()
                        pred = np.argmax(pred, axis=1)
                        labels = labels.data.cpu().numpy()
                        evaluator.add_batch(labels, pred)
                    if phase == 'val':
                        outputs = ema_model(inputs)
                        outputs[1] = F.interpolate(input=outputs[1],
                                                   size=(labels.shape[1],
                                                         labels.shape[2]),
                                                   mode='bilinear',
                                                   align_corners=True)
                        loss = calc_loss(outputs, labels, edge, metrics)
                        pred = outputs[1].data.cpu().numpy()
                        pred = np.argmax(pred, axis=1)
                        labels = labels.data.cpu().numpy()
                        evaluator.add_batch(labels, pred)
                    if phase == 'val' and (
                            epoch +
                            1) % args.vis_frequency == 0 and inputs.shape[
                                0] == args.batchsize:
                        for k in range(args.batchsize // 2):
                            name = datafiles['name'][k][:-4]

                            writer.add_image('{}/img'.format(name),
                                             cv2.cvtColor(
                                                 cv2.imread(
                                                     datafiles["img"][k],
                                                     cv2.IMREAD_COLOR),
                                                 cv2.COLOR_BGR2RGB),
                                             global_step=int((epoch + 1)),
                                             dataformats='HWC')

                            writer.add_image('{}/gt'.format(name),
                                             label_img_to_color(
                                                 labels[k])[:, :, ::-1],
                                             global_step=int((epoch + 1)),
                                             dataformats='HWC')

                            pred_label_img = pred.astype(np.uint8)[k]
                            pred_label_img_color = label_img_to_color(
                                pred_label_img)
                            writer.add_image('{}/mask'.format(name),
                                             pred_label_img_color[:, :, ::-1],
                                             global_step=int((epoch + 1)),
                                             dataformats='HWC')

                            softmax_pred = F.softmax(outputs[1][k], dim=0)
                            softmax_pred_np = softmax_pred.data.cpu().numpy()
                            probility = softmax_pred_np[0]
                            probility = probility * 255
                            probility = probility.astype(np.uint8)
                            probility = cv2.applyColorMap(
                                probility, cv2.COLORMAP_HOT)
                            writer.add_image('{}/prob'.format(name),
                                             cv2.cvtColor(
                                                 probility, cv2.COLOR_BGR2RGB),
                                             global_step=int((epoch + 1)),
                                             dataformats='HWC')
                            # 差分图
                            diff_img = np.ones((pred_label_img.shape[0],
                                                pred_label_img.shape[1]),
                                               dtype=np.int32) * 255
                            mask = (labels[k] != pred_label_img)
                            diff_img[mask] = labels[k][mask]
                            diff_img_color = diff_label_img_to_color(diff_img)
                            writer.add_image('{}/different_image'.format(name),
                                             diff_img_color[:, :, ::-1],
                                             global_step=int((epoch + 1)),
                                             dataformats='HWC')
                    if phase == 'train':
                        loss.backward()
                        global_step += 1
                        optimizer.step()
                        ema_optimizer.step()
                        adjust_learning_rate_poly(
                            args.lr, optimizer,
                            epoch * len(dataloaders['train']) + i,
                            args.num_epochs * len(dataloaders['train']))
                        lr = optimizer.param_groups[0]['lr']
                        writer.add_scalar(
                            'lr',
                            lr,
                            global_step=epoch * len(dataloaders['train']) + i)
                epoch_samples += 1
            epoch_loss = metrics['loss'] / epoch_samples
            ce_loss = metrics['ce_loss'] / epoch_samples
            ls_loss = metrics['ls_loss'] / epoch_samples
            miou = evaluator.Mean_Intersection_over_Union()
            AA = evaluator.Pixel_Accuracy_Class()
            OA = evaluator.Pixel_Accuracy()
            confusion_matrix = evaluator.confusion_matrix
            if phase == 'val':
                miou_mat = evaluator.Mean_Intersection_over_Union_test()
                writer.add_scalar('val/val_loss',
                                  epoch_loss,
                                  global_step=epoch)
                writer.add_scalar('val/ce_loss', ce_loss, global_step=epoch)
                writer.add_scalar('val/ls_loss', ls_loss, global_step=epoch)
                #writer.add_scalar('val/val_fwiou', fwiou, global_step=epoch)
                writer.add_scalar('val/val_miou', miou, global_step=epoch)
                for index in range(args.n_cls):
                    writer.add_scalar('class/{}'.format(index + 1),
                                      miou_mat[index],
                                      global_step=epoch)

                print(
                    '[val]------miou: {:4f}, OA:{:4f}, AA: {:4f}, loss: {:4f}'.
                    format(miou, OA, AA, epoch_loss))
                print(
                    '[val]------miou: {:4f}, OA:{:4f}, AA: {:4f}, loss: {:4f}'.
                    format(miou, OA, AA, epoch_loss),
                    file=F_txt,
                    flush=True)
            if phase == 'train':
                writer.add_scalar('train/train_loss',
                                  epoch_loss,
                                  global_step=epoch)
                writer.add_scalar('train/ce_loss', ce_loss, global_step=epoch)
                writer.add_scalar('train/ls_loss', ls_loss, global_step=epoch)
                #writer.add_scalar('train/train_fwiou', fwiou, global_step=epoch)
                writer.add_scalar('train/train_miou', miou, global_step=epoch)
                print(
                    '[train]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'
                    .format(miou, OA, AA, epoch_loss))
                print(
                    '[train]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'
                    .format(miou, OA, AA, epoch_loss),
                    file=F_txt,
                    flush=True)

            if phase == 'val' and miou > best_miou:
                print("\33[91msaving best model miou\33[0m")
                print("saving best model miou", file=F_txt, flush=True)
                best_miou = miou
                best_OA = OA
                best_AA = AA
                best_loss = epoch_loss
                torch.save(
                    {
                        'name': 'resnest50_lovasz_edge_rotate',
                        'epoch': epoch + 1,
                        'state_dict': ema_model.state_dict(),
                        'best_miou': best_miou
                    }, args.save_path + folder_path + '/model_best.pth')
                torch.save({
                    'optimizer': optimizer.state_dict(),
                }, args.save_path + folder_path + '/optimizer.pth')
        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60),
              file=F_txt,
              flush=True)

    print('[Best val]------miou: {:4f}; OA: {:4f}; AA: {:4f}; loss: {:4f}'.
          format(best_miou, best_OA, best_AA, best_loss))
    print('[Best val]------miou: {:4f}; OA: {:4f}; AA: {:4f}; loss: {:4f}'.
          format(best_miou, best_OA, best_AA, best_loss),
          file=F_txt,
          flush=True)
    F_txt.close()
Exemplo n.º 9
0
    '''
        Train (time < 2015)
    '''
    model.train()
    train_losses = []
    torch.cuda.empty_cache()
    for _ in range(args.repeat):
        for node_feature, node_type, edge_time, edge_index, edge_type, ylabel, node_time in train_data:
            node_rep = gnn.forward(node_feature, node_type.to(device), \
                                   edge_time.to(device), edge_index.to(device), edge_type.to(device))
            res = classifier.forward(node_rep)

            y = ylabel['img'].to(device)
            f1 = node_rep[node_type==0]
            f2 = node_rep[node_type==1]
            loss = calc_loss(f1, f2, res[node_type==0], res[node_type==1], y, y, 1, 1)

            optimizer.zero_grad()
            torch.cuda.empty_cache()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

            train_losses += [loss.cpu().detach().tolist()]
            train_step += 1
            scheduler.step(train_step)
            del res, loss


    '''
Exemplo n.º 10
0
def _record_stop_info(info, fine_tune, stopped_early, lr, stop_epoch,
                      best_loss):
    """Write stopping metadata into *info*.

    Keys are prefixed with 'fine_tune_' when recording the fine-tuning
    phase so both phases can live in the same dictionary.
    """
    prefix = 'fine_tune_' if fine_tune else ''
    info[prefix + 'early stop'] = 'True' if stopped_early else 'False'
    info[prefix + 'stopping LR'] = lr
    info[prefix + 'stopping epoch'] = stop_epoch
    info[prefix + 'best loss'] = best_loss


def train_model(model,
                optimizer,
                scheduler,
                device,
                num_epochs,
                dataloaders,
                info,
                args,
                fine_tune=False):
    """
        Train the model

        Args:
            model: A neural network model for training
            optimizer: An optimizer to calculate gradients
            scheduler: A scheduler (e.g. ReduceLROnPlateau) fed the
                validation loss once per epoch
            device: gpu or cpu
            num_epochs: a number of epochs
            dataloaders: dict with 'train' and 'val' data loaders yielding
                (images, labels) batches
            info: a dictionary to save metrics
            args: namespace providing at least ``earlystop`` (patience)
            fine_tune: If True, it saves metrics of the fine tuning phase

        Return:
            model: A trained model (best validation weights loaded)
            metric_train: Metrics from training phase
            metric_valid: Metrics from validation phase
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')  # any real validation loss beats this

    # Per-epoch losses, returned to the caller for later analysis.
    epoch_train_loss = list()
    epoch_valid_loss = list()
    epoch_train_dice_loss = list()
    epoch_valid_dice_loss = list()
    epoch_train_bce = list()
    epoch_valid_bce = list()

    # Initialize SummaryWriter to visualize losses on Tensorboard
    writer = SummaryWriter()

    # Early stopping monitors the validation loss.
    early_stopping = EarlyStopping(patience=args.earlystop, verbose=True)

    # Training starts
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        since = time.time()
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            # Running loss sums for this epoch, keyed by loss name.
            metrics = defaultdict(float)
            epoch_samples = 0

            # Load a batch of images and labels
            for images, labels in dataloaders[phase]:
                images = images.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # Track gradient history only in the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(images)
                    loss = calc_loss(outputs, labels, metrics)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                epoch_samples += images.size(0)

            # print metrics
            print_metrics(metrics, epoch_samples, phase)

            # save the loss of the current epoch
            epoch_loss = metrics['loss'] / epoch_samples

            if phase == 'train':
                # save training metrics for tensorboard
                writer.add_scalar('Loss(BCE+Dice)/train', epoch_loss, epoch)
                writer.add_scalar('Dice Loss/train',
                                  metrics['dice_loss'] / epoch_samples, epoch)
                writer.add_scalar('BCE/train', metrics['bce'] / epoch_samples,
                                  epoch)

                # save training metrics for later use
                epoch_train_loss.append(metrics['loss'] / epoch_samples)
                epoch_train_bce.append(metrics['bce'] / epoch_samples)
                epoch_train_dice_loss.append(metrics['dice_loss'] /
                                             epoch_samples)

            elif phase == 'val':
                # save validation metrics for tensorboard (LR is logged to
                # make ReduceLROnPlateau reductions visible)
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], epoch)
                writer.add_scalar('Loss(BCE+Dice)/valid',
                                  metrics['loss'] / epoch_samples, epoch)
                writer.add_scalar('Dice Loss/valid',
                                  metrics['dice_loss'] / epoch_samples, epoch)
                writer.add_scalar('BCE/valid', metrics['bce'] / epoch_samples,
                                  epoch)

                # save validation metrics for later use
                epoch_valid_loss.append(metrics['loss'] / epoch_samples)
                epoch_valid_bce.append(metrics['bce'] / epoch_samples)
                epoch_valid_dice_loss.append(metrics['dice_loss'] /
                                             epoch_samples)

                # pass loss to ReduceLROnPlateau scheduler
                scheduler.step(epoch_loss)

                # evaluate early stopping criterion
                early_stopping(epoch_loss, model, optimizer, args)

                # compare loss and deep copy the model
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since  # compute time of epoch
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

        # stop training as soon as the early-stopping criterion fires
        if early_stopping.early_stop:
            print(f"Early stopping after epoch {epoch}")
            _record_stop_info(info, fine_tune, True,
                              optimizer.param_groups[0]['lr'], epoch + 1,
                              best_loss)
            break

    if not early_stopping.early_stop:
        # Training ran all epochs without triggering early stopping.
        _record_stop_info(info, fine_tune, False,
                          optimizer.param_groups[0]['lr'], num_epochs,
                          best_loss)

    print('Best val loss: {:4f}'.format(best_loss))

    # collect all metrics
    metric_train = (epoch_train_loss, epoch_train_bce, epoch_train_dice_loss)
    metric_valid = (epoch_valid_loss, epoch_valid_bce, epoch_valid_dice_loss)

    # load best model weights (necessary for fine tuning of ResNet-UNet)
    model.load_state_dict(best_model_wts)

    writer.close()  # end tensorboard writing

    return model, metric_train, metric_valid
Exemplo n.º 11
0
def test_bdclstm(bdclstm, dataloaders, root_dir, test_subjects, device,
                 save_dir):
    """Run inference with the BDC-LSTM and write per-subject NIfTI predictions.

    Patch-level predictions are binarized, averaged back into each subject's
    full volume (patch origin encoded in the sample name), thresholded, and
    saved with the geometry of the original label volume.

    Args:
        bdclstm: trained BDC-LSTM network (set to eval mode here)
        dataloaders: loader yielding (inputs, labels, subject, name) batches
        root_dir: dataset root containing 'npy_labels' and 'Labels' dirs
        test_subjects: subject ids used to pre-allocate result volumes
        device: torch device batches are moved to
        save_dir: directory where per-subject 'pred.nii' files are written

    Returns:
        bdclstm (unchanged), for call-chaining.
    """
    bdclstm.eval()

    GT_dir = os.path.join(root_dir, 'npy_labels')
    # One zero-initialized accumulation volume per test subject.
    subject_seg = set_subject_result(GT_dir, test_subjects)

    metrics = defaultdict(float)
    epoch_samples = 0
    with torch.no_grad():
        for inputs, labels, subject, name in tqdm(dataloaders):
            subject = subject[0]
            name = name[0].replace('.npy', '')

            labels = labels.to(device)
            inputs = inputs.to(device)

            outputs = bdclstm(inputs)

            # Binarize sigmoid probabilities at 0.5.
            thresh_outputs = torch.sigmoid(outputs)
            thresh_outputs[thresh_outputs >= 0.5] = 1.0
            thresh_outputs[thresh_outputs < 0.5] = 0.0

            # calc_loss accumulates running sums into `metrics`; the
            # returned loss/dice values are not needed here.
            calc_loss(outputs, labels, metrics)

            pred = np.squeeze(thresh_outputs.cpu().data[0].numpy()).astype(
                np.float16)

            # The sample name encodes the patch origin as '<id>_z_x_y';
            # average overlapping patches into the subject volume.
            p_z, p_x, p_y = pred.shape
            _, z, x, y = name.split('_')
            z, x, y = int(z), int(x), int(y)
            subject_seg[subject][z:z + p_z, x:x + p_x, y:y + p_y] = (
                subject_seg[subject][z:z + p_z, x:x + p_x, y:y + p_y] +
                pred) / 2

            epoch_samples += inputs.size(0)
        print_metrics(metrics, epoch_samples, 'test')

    source_dir = os.path.join(root_dir, 'Labels')
    for key in subject_seg:
        # Read the original label volume so the prediction inherits its
        # physical geometry (direction / origin / spacing).
        origin_label = os.path.join(source_dir, 'label{}.nii.gz'.format(key))
        origin_3D = sitk.ReadImage(origin_label, sitk.sitkInt16)

        subject_dir = make_dir(save_dir, key)
        save_path = os.path.join(subject_dir, 'pred.nii')

        # Final binarization of the averaged probability map.
        subject_seg[key][subject_seg[key] >= 0.4] = 1
        subject_seg[key][subject_seg[key] < 0.4] = 0

        result_3D = sitk.GetImageFromArray(subject_seg[key].astype(np.uint8))

        result_3D.SetDirection(origin_3D.GetDirection())
        result_3D.SetOrigin(origin_3D.GetOrigin())
        result_3D.SetSpacing(origin_3D.GetSpacing())

        sitk.WriteImage(result_3D, save_path)
        del result_3D  # free the large volume before the next subject

    return bdclstm
Exemplo n.º 12
0
def main(args):
    """Train an M2Det detector on a tfrecord dataset (TensorFlow 1.x, graph mode).

    Builds the input pipeline, network, losses and optimizer, then runs the
    training loop with periodic logging, summaries and checkpointing.

    Args:
        args: parsed command-line namespace; the attributes unpacked below
            (batch_size, tfrecord_dir, log_dir, sfam, scale, model_dir,
            load_num, epoches, save_weight_period) configure the run.
    """
    #******************************************** load args
    batch_size = args.batch_size
    data_record_dir = args.tfrecord_dir
    log_dir = args.log_dir
    sfam_fg = args.sfam
    box_loss_scale = args.scale
    model_dir = args.model_dir
    load_num = args.load_num
    epoches = args.epoches
    save_weight_period = args.save_weight_period
    #******************************************** create logging
    # Log to a timestamped file (':' replaced for filesystem safety) and stdout.
    logger = logging.getLogger()
    hdlr = logging.FileHandler(os.path.join(log_dir,time.strftime('%F-%T',time.localtime()).replace(':','-')+'.log'))
    #formatter = logging.Formatter('[%(asctime)s] [%(levelname)s] [%(threadName)-10s] %(message)s')
    #hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    #******************************************** load data
    y_true_size = cfgs.ClsNum + 6
    with tf.name_scope("Load_data"):
        '''
        databox = Data(image_dir=args.image_dir, 
                    label_dir=args.label_dir, 
                    assignment_threshold=args.assignment_threshold)
        databox.start()
        dataset_size = databox.size
        logger.info('Dataset size: {}'.format(dataset_size))
        '''
        '''
        y_true_size = 4 + num_classes + 1 + 1 = num_classes + 6
        * 4 => bbox coordinates (x1, y1, x2, y2);
        * num_classes + 1 => including a background class;
        * 1 => denotes if the prior box was matched to some gt boxes or not;
        '''
        tfrd = Read_Tfrecord(cfgs.DataSet_Name,data_record_dir,batch_size,1000)
        num_obj_batch, img_batch, gtboxes_label_batch = tfrd.next_batch()
        # Anchors come from a py_func, so static shapes must be set by hand.
        anchors = tf.py_func(generate_anchors,[],tf.float32)
        anchors.set_shape([None,4])
        x_batch,y_true = tf.py_func(process_imgs,[img_batch,gtboxes_label_batch,num_obj_batch,batch_size,anchors],[tf.float32,tf.float32])
        x_batch.set_shape([None,cfgs.ImgSize,cfgs.ImgSize,3])
        y_true.set_shape([None,cfgs.AnchorBoxes,y_true_size])
    #******************************************** build network
    #inputs = tf.placeholder(tf.float32, [None, cfgs.ImgSize, cfgs.ImgSize, 3])
    #y_true = tf.placeholder(tf.float32, [None, cfgs.AnchorBoxes, y_true_size])
    is_training = tf.constant(True)
    net = M2Det(x_batch, is_training, sfam_fg)
    y_pred = net.prediction
    with tf.name_scope("Losses"):
        total_loss,bbox_loss,class_loss = calc_loss(y_true, y_pred, box_loss_weight=box_loss_scale)
        ave_box_loss = tf.reduce_mean(bbox_loss)
        ave_clss_loss = tf.reduce_mean(class_loss)
        tf.summary.scalar('cls_box/cb_loss', total_loss)
        # L2 weight decay summed over every kernel variable in the graph.
        weights = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if 'kernel' in v.name]
        decay = tf.reduce_sum(tf.stack([tf.nn.l2_loss(w) for w in weights])) * 5e-4
        tf.summary.scalar('weight/weight_loss', decay)
        total_loss += decay
        tf.summary.scalar('total/total_loss', total_loss)
        tf.summary.scalar('cls_box/ave_box_loss',ave_box_loss)
        tf.summary.scalar('cls_box/ave_class_loss',ave_clss_loss)
        tf.summary.histogram('bbox/box_loss',bbox_loss)
        tf.summary.histogram('class/class_loss',class_loss)
    #*************************************************************** build trainer
    with tf.name_scope("optimizer"):
        global_step = tf.train.get_or_create_global_step()
        # Piecewise-constant learning-rate schedule taken from the config.
        lr = tf.train.piecewise_constant(global_step,
                                        boundaries=[np.int64(x) for x in cfgs.DECAY_STEP],
                                        values=[y for y in cfgs.LR])
        tf.summary.scalar('lr', lr)
        #optimizer = tf.train.MomentumOptimizer(lr, momentum=cfgs.MOMENTUM)
        #global_step = tf.Variable(0, name='global_step', trainable=False)
        train_var = tf.trainable_variables()
        # Run UPDATE_OPS (e.g. batch-norm moving averages) before each step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            #opt = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)
            opt = tf.train.AdamOptimizer(learning_rate=lr)
            grads = tf.gradients(total_loss, train_var)
            #tf.summary.histogram("Grads/")
            train_op = opt.apply_gradients(zip(grads, train_var), global_step=global_step)
    #*********************************************************************************** training
    with tf.name_scope("training_op"):
        tf_config = tf.ConfigProto()
        #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
        #tf_config.gpu_options = gpu_options
        tf_config.gpu_options.allow_growth=True
        tf_config.log_device_placement=False
        sess = tf.Session(config=tf_config)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        saver = tf.train.Saver(max_to_keep=10)
        # load model (resume from a checkpoint when load_num is given)
        model_path = os.path.join(model_dir,cfgs.DataSet_Name)
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        model_path = os.path.join(model_path,cfgs.ModelPrefix)
        if load_num is not None :
            #assert tf.train.get_checkpoint_state(model_dir),'the params dictionary is not valid'
            model_path = "%s-%s" %(model_path,load_num)
            saver.restore(sess, model_path)
            logger.info('Resuming training %s' % model_path)
        # build summary
        summary_op = tf.summary.merge_all()
        summary_path = os.path.join(log_dir,'summary')
        if not os.path.exists(summary_path):
            os.makedirs(summary_path)
        summary_writer = tf.summary.FileWriter(summary_path, graph=sess.graph)
        # begin to train
        try:
            for epoch_tmp in range(epoches):
                for step in range(np.ceil(cfgs.Train_Num/batch_size).astype(np.int32)):
                    training_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
                    global_value = sess.run(global_step)
                    # Fetch extra tensors only on logging/summary iterations.
                    if step % cfgs.SHOW_TRAIN_INFO !=0 and step % cfgs.SMRY_ITER !=0:
                        _ = sess.run([train_op])
                    else:
                        if step % cfgs.SHOW_TRAIN_INFO ==0:
                            _, loss_value,cur_lr,ave_box,ave_cls = sess.run([train_op, total_loss,opt._lr,ave_box_loss,ave_clss_loss])
                            logger.info('{} \t epoch:{}, lr:{}, step: {}, loss: {} , bbox:{}, cls:{}'.format(str(training_time),epoch_tmp,cur_lr,global_value, loss_value,ave_box,ave_cls))
                        if step % cfgs.SMRY_ITER ==0:
                            _, summary_str = sess.run([train_op,summary_op])
                            summary_writer.add_summary(summary_str,global_value)
                            summary_writer.flush()
                # Save weights periodically and always at the last epoch.
                if (epoch_tmp > 0 and epoch_tmp % save_weight_period == 0) or (epoch_tmp == epoches - 1):
                    dst = model_path
                    saver.save(sess, dst, epoch_tmp,write_meta_graph=False)
                    logger.info(">>*************** save weight ***: %d" % epoch_tmp)
        except tf.errors.OutOfRangeError:
            print("Trianing is error")
        finally:
            # Always stop the queue runners and release the session.
            coord.request_stop()
            summary_writer.close()
            coord.join(threads)
            #record_file_out.close()
            sess.close()
    '''
        Train (time < 2015)
    '''
    model.train()
    train_losses = []
    torch.cuda.empty_cache()
    for _ in range(args.repeat):
        for node_feature, node_type, edge_time, edge_index, edge_type, ylabel, node_time in train_data:
            node_rep = gnn.forward(node_feature, node_type.to(device), \
                                   edge_time.to(device), edge_index.to(device), edge_type.to(device))
            res = classifier.forward(node_rep)

            y = ylabel['img'].to(device)
            f1 = node_rep[node_type==0]
            f2 = node_rep[node_type==1]
            loss = calc_loss(f1, f2, res[node_type==0], res[node_type==1], y, y, alpha, beta)

            optimizer.zero_grad()
            torch.cuda.empty_cache()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

            train_losses += [loss.cpu().detach().tolist()]
            train_step += 1
            scheduler.step(train_step)
            del res, loss


    '''
def train_bdclstm(bdclstm,
                  dataloaders,
                  optimizer,
                  scheduler,
                  device,
                  bce_weight,
                  final_model_path,
                  num_epochs=40):
    """Train the BDC-LSTM and keep the weights with the best validation dice.

    Args:
        bdclstm: the BDC-LSTM network to train
        dataloaders: dict with 'train' and 'val' loaders yielding
            (inputs, labels, subject, name) tuples
        optimizer: optimizer updating bdclstm's parameters
        scheduler: LR scheduler stepped once per epoch (train phase)
        device: torch device the batches are moved to
        bce_weight: weight of the BCE term inside calc_loss
        final_model_path: file path where the best weights are checkpointed
        num_epochs: number of epochs to run

    Returns:
        bdclstm with the best validation weights loaded.
    """
    best_model_wts = copy.deepcopy(bdclstm.state_dict())
    # The 'dice' metric is loss-like (lower is better); any real epoch
    # value below 1 replaces this initial best.
    best_loss = 1

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()  # advance the LR schedule once per epoch
                for param_group in optimizer.param_groups:
                    print('LR', param_group['lr'])
                bdclstm.train()
            else:
                bdclstm.eval()

            # Running loss sums for this epoch, accumulated by calc_loss.
            metrics = defaultdict(float)
            epoch_samples = 0
            for inputs, labels, _, _ in tqdm(dataloaders[phase]):
                labels = labels.to(device)
                inputs = inputs.to(device)

                optimizer.zero_grad()
                # Track gradients only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = bdclstm(inputs)
                    loss, _ = calc_loss(outputs,
                                        labels,
                                        metrics,
                                        bce_weight=bce_weight)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                epoch_samples += inputs.size(0)
            print_metrics(metrics, epoch_samples, phase)

            # Average validation dice (lower is better) decides the best
            # checkpoint.
            epoch_loss = metrics['dice'] / epoch_samples
            if phase == 'val' and epoch_loss < best_loss:
                print('best loss changed!')
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(bdclstm.state_dict())

        # Checkpoint the best-so-far weights after every epoch.
        print('save temporal model!')
        torch.save(best_model_wts, final_model_path)

    bdclstm.load_state_dict(best_model_wts)
    return bdclstm