Example #1
def run():
    """创建模型"""
    model = FFN(in_channels=2, out_channels=1).cuda()
    """数据路径"""
    input_h5data = [args.data]
    """创建data loader"""
    train_dataset = BatchCreator(input_h5data,
                                 args.input_size,
                                 delta=args.delta,
                                 train=True)
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              num_workers=1,
                              pin_memory=True)

    optimizer = optim.SGD(
        model.parameters(),
        lr=1e-3,  # fixed learning rate; no scheduler is used in this example
        momentum=0.9,
        weight_decay=0.5e-4,
    )

    best_loss = np.inf
    """获取数据流"""
    for itr, (seeds, images, labels, offsets) in enumerate(
            get_batch(train_loader, args.batch_size, args.input_size,
                      partial(fixed_offsets, fov_moves=train_dataset.shifts))):
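
        # Each batch concatenates the raw image with the current seed mask along
        # the channel axis, matching the model's in_channels=2.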

        input_data = torch.cat([images, seeds], dim=1)

        input_data = input_data.cuda()
        seeds = seeds.cuda()
        labels = labels.cuda()

        logits = model(input_data)

        # FFN-style update: the network predicts a correction in logit space
        # that is added to the current seed mask.
        updated = seeds + logits
        optimizer.zero_grad()
        loss = F.binary_cross_entropy_with_logits(updated, labels)
        loss.backward()
        """梯度截断"""
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad_thr)

        optimizer.step()

        # Voxel-wise accuracy proxy: fraction of voxels whose predicted
        # probability lies within 0.001 of the label.
        diff = (updated.sigmoid() - labels).detach().cpu().numpy()
        accuracy = 1.0 * (np.abs(diff) < 0.001).sum() / np.prod(labels.shape)
        print("loss: {}, iter: {}, Accuracy: {:.2f}% ".format(
            loss.item(), itr,
            accuracy * 100))

        # update_seed(updated, seeds, model, offsets)
        # seed = updated
        """根据最佳loss并且保存模型"""
        if best_loss > loss.item():
            best_loss = loss.item()
            torch.save(model.state_dict(),
                       os.path.join(args.save_path, 'ffn.pth'))
            print('Model saved!')
Example #2
def run():

    cudnn.benchmark = True
    torch.cuda.set_device(args.local_rank)
    
    # will read env master_addr master_port world_size
    torch.distributed.init_process_group(backend='nccl', init_method="env://")
    args.world_size = dist.get_world_size()
    args.rank = dist.get_rank()
    # args.local_rank = int(os.environ.get('LOCALRANK', args.local_rank))
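    # Effective global batch size is the per-GPU batch size times the world size.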
    args.total_batch_size = args.batch_size * dist.get_world_size()
    
    global resume_iter
    """model_log"""
    input_size_r = list(args.input_size)
    delta_r = list(args.delta)

    path = args.log_save_path + "model_log_fov:{}_delta:{}_depth:{}".format(
        input_size_r[0], delta_r[0], args.depth)
    # Create the log file if it is missing or empty; otherwise recover the
    # resume iteration from the number of entries already logged.
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        with open(path, 'wb') as f:
            data_start = {'chris': "xtx"}
            pickle.dump(data_start, f)
    else:
        with open(path, 'rb') as f:
            data = pickle.load(f)
        resume_iter = len(data.keys()) - 1


    """model_construction"""
    model = FFN(in_channels=4, out_channels=1, input_size=args.input_size, delta=args.delta, depth=args.depth).cuda()
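    # Sync BatchNorm statistics across ranks, then wrap with DDP so gradients
    # are averaged across processes during backward().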
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)


    """data_load"""
    if args.resume is not None:
        model.load_state_dict(torch.load(args.resume))


    abs_path_training_data = args.train_data_dir
    entries_train_data = Path(abs_path_training_data)
    files_train_data = []

    for entry in entries_train_data.iterdir():
        files_train_data.append(entry.name)

    sorted_files_train_data = natsort.natsorted(files_train_data, reverse=False)

    files_total = len(sorted_files_train_data)
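
    # One dataset / DistributedSampler / DataLoader / batch iterator is built
    # per HDF5 volume so the train loop can sample a random volume each step.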

    input_h5data_dict = {}
    train_dataset_dict = {}
    train_loader_dict = {}
    batch_it_dict = {}
    train_sampler_dict = {}

    for index in range(files_total):
        input_h5data_dict[index] = [(abs_path_training_data + sorted_files_train_data[index])]
        print(input_h5data_dict[index])
        train_dataset_dict[index] = BatchCreator(input_h5data_dict[index], args.input_size, delta=args.delta, train=True)
        train_sampler_dict[index] = torch.utils.data.distributed.DistributedSampler(train_dataset_dict[index], num_replicas=args.world_size, rank=args.rank, shuffle=True)
        train_loader_dict[index] = DataLoader(train_dataset_dict[index], num_workers=0, sampler=train_sampler_dict[index], pin_memory=True)
        batch_it_dict[index] = get_batch(train_loader_dict[index], args.batch_size, args.input_size,
                               partial(fixed_offsets, fov_moves=train_dataset_dict[index].shifts))




    best_loss = np.inf

    """optimizer"""
    t_last = time.time()
    cnt = 0
    tp = fp = tn = fn = 0
    
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    #optimizer = optim.SGD(model.parameters(), lr=1e-3) 
    #momentum=0.9 
    #optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.step, gamma=args.gamma, last_epoch=-1)


    """train_loop"""
    while cnt < args.iter:
        cnt += 1

        # Pick a random training volume and pull its next batch.
        Num_of_train_data = len(input_h5data_dict)
        index_rand = random.randrange(0, Num_of_train_data)

        seeds, images, labels, offsets = next(batch_it_dict[index_rand])
        #print(sorted_files_train_data[index_rand])
        #seeds = seeds.cuda(non_blocking=True)
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        #offsets = offsets.cuda(non_blocking=True)

        t_curr = time.time()

        torch_seed = torch.from_numpy(seeds).cuda(non_blocking=True)
        input_data = torch.cat([images, torch_seed], dim=1)  # image + seed channels

        logits = model(input_data)
        updated = torch_seed + logits

        optimizer.zero_grad()
        loss = F.binary_cross_entropy_with_logits(updated, labels)
        loss.backward()

        torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_thr)
        optimizer.step()
        
        
        # Write the refined prediction back into the shared seed buffer so the
        # next FOV move starts from the updated mask.
        seeds[...] = updated.detach().cpu().numpy()

        # Accumulate confusion-matrix counts; a voxel counts as foreground when
        # its logit exceeds logit(0.9), i.e. predicted probability >= 0.9.
        pred_mask = (updated >= logit(0.9)).detach().cpu().numpy()
        true_mask = (labels > 0.5).cpu().numpy()
        true_bg = np.logical_not(true_mask)
        pred_bg = np.logical_not(pred_mask)
        tp += (true_mask & pred_mask).sum()
        fp += (true_bg & pred_mask).sum()
        fn += (true_mask & pred_bg).sum()
        tn += (true_bg & pred_bg).sum()
        precision = 1.0 * tp / max(tp + fp, 1)
        recall = 1.0 * tp / max(tp + fn, 1)
        accuracy = 1.0 * (tp + tn) / (tp + tn + fp + fn)
        if args.rank == 0:
            print('[Iter_{}:, loss: {:.4}, Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%]\r'.format(
                cnt, loss.item(), precision * 100, recall * 100, accuracy * 100))

        #scheduler.step()


        """model_saving_(best_loss)"""
        """
        if best_loss > loss.item() or t_curr - t_last > args.interval:
            tp = fp = tn = fn = 0
            t_last = t_curr
            best_loss = loss.item()
            input_size_r = list(args.input_size)
            delta_r = list(args.delta)
            torch.save(model.state_dict(), os.path.join(args.save_path,
                                                        'ffn_model_fov:{}_delta:{}_depth:{}.pth'.format(input_size_r[0],
                                                                                                        delta_r[0],
                                                                                                        args.depth)))
            print('Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%, Model saved!'.format(
                precision * 100, recall * 100, accuracy * 100))
        """

        """model_saving_(iter)"""


        if (cnt % args.save_interval) == 0 and args.rank == 0:
            tp = fp = tn = fn = 0
            #t_last = t_curr
            #best_loss = loss.item()
            input_size_r = list(args.input_size)
            delta_r = list(args.delta)
            torch.save(model.state_dict(), os.path.join(
                args.save_path,
                str(args.stream) + 'ffn_model_fov:{}_delta:{}_depth:{}_recall{}.pth'.format(
                    input_size_r[0], delta_r[0], args.depth, recall * 100)))
            print('Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%, Model saved!'.format(
                precision * 100, recall * 100, accuracy * 100))


            path = args.log_save_path + "model_log_fov:{}_delta:{}_depth:{}".format(
                input_size_r[0], delta_r[0], args.depth)
            model_eval = "precision#" + str('%.4f' % (precision * 100)) + "#recall#" + str('%.4f' % (recall * 100)) + "#accuracy#" + str('%.4f' % (accuracy * 100))

            # Append this checkpoint's metrics to the pickled log, keyed by the
            # number of completed save intervals (integer division keeps the
            # key an int across resumes).
            with open(path, 'rb') as f_l:
                data = pickle.load(f_l)

            key = cnt // args.save_interval + resume_iter
            data[key] = model_eval

            with open(path, 'wb') as f_o:
                pickle.dump(data, f_o)
Example #3
def run():
    """创建模型"""
    model = FFN(in_channels=4,
                out_channels=1,
                input_size=args.input_size,
                delta=args.delta,
                depth=args.depth).cuda()
    if args.resume is not None:
        model.load_state_dict(torch.load(args.resume))
    """数据路径"""
    input_h5data = [args.data]
    """创建data loader"""
    train_dataset = BatchCreator(input_h5data,
                                 args.input_size,
                                 delta=args.delta,
                                 train=True)
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.step,
                                                gamma=args.gamma,
                                                last_epoch=-1)
    # Note: the loss below uses F.binary_cross_entropy_with_logits with the
    # positive-class weight; CrossEntropyLoss would expect class-index targets
    # and does not fit a single-channel float mask.
    best_loss = np.inf
    """获取数据流"""
    t_last = time.time()
    cnt = 0
    tp = fp = tn = fn = 0
    batch_it = get_batch(
        train_loader, args.batch_size, args.input_size,
        partial(fixed_offsets, fov_moves=train_dataset.shifts))
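    # batch_it yields (seeds, images, labels, offsets); seeds is a NumPy buffer
    # that is updated in place after each step (see seeds[...] below).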
    while cnt < args.iter:
        cnt += 1
        seeds, images, labels, offsets = next(batch_it)
        t_curr = time.time()
        """正样本权重"""
        pos_w = torch.tensor([1]).float().cuda()
        #slice = sigmoid(seeds[:, :, seeds.shape[2] // 2, :, :])
        #seeds[:, :, seeds.shape[2] // 2, :, :] = slice
        labels = labels.cuda()
        torch_seed = torch.from_numpy(seeds)
        input_data = torch.cat([images, torch_seed], dim=1)
        input_data = input_data.cuda()
        out = model(input_data)
        updated = torch_seed.cuda() + out
        optimizer.zero_grad()
        loss = F.binary_cross_entropy_with_logits(updated, labels, pos_weight=pos_w)
        loss.backward()
        """梯度截断"""
        torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_thr)
        optimizer.step()
        seeds[...] = updated.detach().cpu().numpy()
        pred_mask = (updated >= logit(0.9)).detach().cpu().numpy()
        true_mask = (labels > 0.5).cpu().numpy()
        true_bg = np.logical_not(true_mask)
        pred_bg = np.logical_not(pred_mask)
        tp += (true_mask & pred_mask).sum()
        fp += (true_bg & pred_mask).sum()
        fn += (true_mask & pred_bg).sum()
        tn += (true_bg & pred_bg).sum()
        precision = 1.0 * tp / max(tp + fp, 1)
        recall = 1.0 * tp / max(tp + fn, 1)
        accuracy = 1.0 * (tp + tn) / (tp + tn + fp + fn)
        print(
            '[Iter_{}:, loss: {:.4}, Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%]\r'
            .format(cnt, loss.item(), precision * 100, recall * 100,
                    accuracy * 100))
        scheduler.step()
        """根据最佳loss并且保存模型"""
        if best_loss > loss.item() or t_curr - t_last > args.interval:
            tp = fp = tn = fn = 0
            t_last = t_curr
            best_loss = loss.item()
            torch.save(model.state_dict(),
                       os.path.join(args.save_path, 'ffn.pth'))
            print(
                'Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%, Model saved!'
                .format(precision * 100, recall * 100, accuracy * 100))
Example #4
def run():


    """model_construction"""
    model = FFN(in_channels=4, out_channels=1, input_size=args.input_size, delta=args.delta, depth=args.depth).cuda()



    """data_load"""
    if args.resume is not None:
        model.load_state_dict(torch.load(args.resume))


    abs_path_training_data = args.train_data_dir
    entries_train_data = Path(abs_path_training_data)
    files_train_data = []

    for entry in entries_train_data.iterdir():
        files_train_data.append(entry.name)

    sorted_files_train_data = natsort.natsorted(files_train_data, reverse=False)

    files_total = len(sorted_files_train_data)

    input_h5data_dict = {}
    train_dataset_dict = {}
    train_loader_dict = {}
    batch_it_dict = {}

    for index in range(files_total):
        input_h5data_dict[index] = [(abs_path_training_data + sorted_files_train_data[index])]
        train_dataset_dict[index] = BatchCreator(input_h5data_dict[index], args.input_size, delta=args.delta, train=True)
        train_loader_dict[index] = DataLoader(train_dataset_dict[index], shuffle=True, num_workers=0, pin_memory=True)
        batch_it_dict[index] = get_batch(train_loader_dict[index], args.batch_size, args.input_size,
                               partial(fixed_offsets, fov_moves=train_dataset_dict[index].shifts))




    best_loss = np.inf

    """optimizer"""
    t_last = time.time()
    cnt = 0
    tp = fp = tn = fn = 0
    
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    #optimizer = optim.SGD(model.parameters(), lr=1e-3) 
    #momentum=0.9 
    #optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.step, gamma=args.gamma, last_epoch=-1)
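    # StepLR decays the learning rate by args.gamma every args.step scheduler
    # steps; scheduler.step() is called once per training iteration below.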


    """train_loop"""
    while cnt < args.iter:
        cnt += 1

        Num_of_train_data = len(input_h5data_dict)
        index_rand = random.randrange(0, Num_of_train_data, 1)

        seeds, images, labels, offsets = next(batch_it_dict[index_rand])
        #print(sorted_files_train_data[index_rand])

        t_curr = time.time()

        labels = labels.cuda()

        torch_seed = torch.from_numpy(seeds)
        input_data = torch.cat([images, torch_seed], dim=1)
        input_data = input_data.cuda()

        logits = model(input_data)
        updated = torch_seed.cuda() + logits

        optimizer.zero_grad()
        loss = F.binary_cross_entropy_with_logits(updated, labels)
        loss.backward()

        #torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_thr)
        optimizer.step()
        
        
        seeds[...] = updated.detach().cpu().numpy()

        pred_mask = (updated >= logit(0.9)).detach().cpu().numpy()
        true_mask = (labels > 0.5).cpu().numpy()
        true_bg = np.logical_not(true_mask)
        pred_bg = np.logical_not(pred_mask)
        tp += (true_mask & pred_mask).sum()
        fp += (true_bg & pred_mask).sum()
        fn += (true_mask & pred_bg).sum()
        tn += (true_bg & pred_bg).sum()
        precision = 1.0 * tp / max(tp + fp, 1)
        recall = 1.0 * tp / max(tp + fn, 1)
        accuracy = 1.0 * (tp + tn) / (tp + tn + fp + fn)
        print('[Iter_{}:, loss: {:.4}, Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%]\r'.format(
            cnt, loss.item(), precision*100, recall*100, accuracy * 100))

        scheduler.step()


        """model_saving_(best_loss)"""
        """
        if best_loss > loss.item() or t_curr - t_last > args.interval:
            tp = fp = tn = fn = 0
            t_last = t_curr
            best_loss = loss.item()
            input_size_r = list(args.input_size)
            delta_r = list(args.delta)
            torch.save(model.state_dict(), os.path.join(args.save_path,
                                                        'ffn_model_fov:{}_delta:{}_depth:{}.pth'.format(input_size_r[0],
                                                                                                        delta_r[0],
                                                                                                        args.depth)))
            print('Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%, Model saved!'.format(
                precision * 100, recall * 100, accuracy * 100))
        """

        """model_saving_(iter)"""


        if (cnt % args.save_interval) == 0:
            tp = fp = tn = fn = 0
            #t_last = t_curr
            #best_loss = loss.item()
            input_size_r = list(args.input_size)
            delta_r = list(args.delta)
            torch.save(model.state_dict(), os.path.join(
                args.save_path,
                str(args.stream) + 'ffn_model_fov:{}_delta:{}_depth:{}_recall{}.pth'.format(
                    input_size_r[0], delta_r[0], args.depth, recall * 100)))
            print('Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%, Model saved!'.format(
                precision * 100, recall * 100, accuracy * 100))
Example #5
def run():
    """model init"""
    model = FFN(in_channels=4,
                out_channels=1,
                input_size=args.input_size,
                delta=args.delta,
                depth=args.depth).cuda()
    """model resume"""
    if args.resume is not None:
        model.load_state_dict(torch.load(args.resume))

    if os.path.exists(args.save_path + 'resume_step.pkl'):
        resume = load_obj(args.save_path + 'resume_step.pkl')
    else:
        resume = {'resume_step': args.resume_step}
    args.resume_step = resume['resume_step']
    print('resume_step', args.resume_step)
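    # resume_step is persisted to disk so TensorBoard logging can continue
    # from the last global step after a restart.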

    if args.tb is None:
        tb = SummaryWriter('./tensorboard/' + args.tag +
                           'tb_train_log_fov:{}_delta:{}_depth:{}.pth'.format(
                               list(args.input_size)[0],
                               list(args.delta)[0], args.depth))
    else:
        tb = SummaryWriter(args.tb)

    sorted_files_train_data = sort_files(args.train_data_dir)
    files_total = len(sorted_files_train_data)
    input_h5data_dict = {}
    train_dataset_dict = {}
    train_loader_dict = {}
    batch_it_dict = {}
    for index in range(files_total):
        input_h5data_dict[index] = [
            (args.train_data_dir + sorted_files_train_data[index])
        ]
        train_dataset_dict[index] = BatchCreator(input_h5data_dict[index],
                                                 args.input_size,
                                                 delta=args.delta,
                                                 train=True)
        train_loader_dict[index] = DataLoader(train_dataset_dict[index],
                                              shuffle=True,
                                              num_workers=0,
                                              pin_memory=True)
        batch_it_dict[index] = get_batch(
            train_loader_dict[index], args.batch_size, args.input_size,
            partial(fixed_offsets, fov_moves=train_dataset_dict[index].shifts))
    """optimizer"""
    if args.opt == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        optimizer = optim.SGD(model.parameters(), lr=1e-3)
    # optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.step, gamma=args.gamma, last_epoch=-1)
    """train_loop"""
    t_last = time.time()
    cnt = 0
    tp = fp = tn = fn = 0
    best_loss = np.inf
    while cnt < args.iter:
        cnt += 1

        # record training iter every 1000
        if cnt % 1000 == 0:
            resume['resume_step'] = cnt + args.resume_step
            pickle_obj(resume, 'resume_step', args.save_path)

        # load training data (random)
        train_num = len(input_h5data_dict)
        index_rand = random.randrange(0, train_num, 1)
        seeds, images, labels, offsets = next(batch_it_dict[index_rand])
        print(input_h5data_dict[index_rand])

        t_curr = time.time()
        labels = labels.cuda()
        torch_seed = torch.from_numpy(seeds)
        input_data = torch.cat([images, torch_seed], dim=1)
        input_data = input_data.cuda()

        logits = model(input_data)
        updated = torch_seed.cuda() + logits

        optimizer.zero_grad()
        loss = F.binary_cross_entropy_with_logits(updated, labels)
        loss.backward()
        #torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_thr)
        optimizer.step()

        seeds[...] = updated.detach().cpu().numpy()

        pred_mask = (updated >= logit(0.8)).detach().cpu().numpy()
        true_mask = (labels > 0.5).cpu().numpy()
        true_bg = np.logical_not(true_mask)
        pred_bg = np.logical_not(pred_mask)
        tp += (true_mask & pred_mask).sum()
        fp += (true_bg & pred_mask).sum()
        fn += (true_mask & pred_bg).sum()
        tn += (true_bg & pred_bg).sum()
        precision = 1.0 * tp / max(tp + fp, 1)
        recall = 1.0 * tp / max(tp + fn, 1)
        accuracy = 1.0 * (tp + tn) / (tp + tn + fp + fn)
        print(
            '[Iter_{}:, loss: {:.4}, Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%]\r'
            .format(cnt, loss.item(), precision * 100, recall * 100,
                    accuracy * 100))

        # scheduler.step()
        """model_saving_(iter)"""

        if (cnt % args.save_interval) == 0:
            tp = fp = tn = fn = 0
            # t_last = t_curr
            # best_loss = loss.item()
            input_size_r = list(args.input_size)
            delta_r = list(args.delta)
            torch.save(
                model.state_dict(),
                os.path.join(args.save_path,
                             (str(args.tag) +
                              'ffn_model_fov:{}_delta:{}_depth:{}.pth'.format(
                                  input_size_r[0], delta_r[0], args.depth))))
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.save_path,
                    (str(args.tag) +
                     'ffn_model_fov:{}_delta:{}_depth:{}_pre{}_recall{}_.pth'.
                     format(input_size_r[0], delta_r[0], args.depth,
                            precision * 100, recall * 100))))

            print(
                'Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%, Model saved!'
                .format(precision * 100, recall * 100, accuracy * 100))

            # Offset the TensorBoard global step by the resumed iteration count,
            # minus a fixed buffer; logging is skipped during the buffer window.
            buffer_step = 3000
            resume_step = args.resume_step - buffer_step
            if cnt > buffer_step:
                tb.add_scalar("Loss", loss.item(), cnt + resume_step)
                tb.add_scalar("Precision", precision * 100, cnt + resume_step)
                tb.add_scalar("Recall", recall * 100, cnt + resume_step)
                tb.add_scalar("Accuracy", accuracy * 100, cnt + resume_step)
Example #6
def run():
  
  
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # Horovod: initialize library.
    hvd.init()
    torch.manual_seed(args.seed)
    if args.cuda:
      
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)
        
    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)
  
    """model_init"""
    model = FFN_no_norm(in_channels=4, out_channels=1, input_size=args.input_size, delta=args.delta, depth=args.depth)
  
    #hvd ddl
    # By default, Adasum doesn't need scaling up learning rate.
    lr_scaler = hvd.size() if not args.use_adasum else 1

    if args.cuda:
        # Move model to GPU.
        model.cuda()
        # If using GPU Adasum allreduce, scale learning rate by local_size.
        if args.use_adasum and hvd.nccl_built():
            lr_scaler = hvd.local_size()

    # Horovod: scale learning rate by lr_scaler.
    optimizer = optim.SGD(model.parameters(), lr=args.lr * lr_scaler)
  

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(optimizer,
                                         named_parameters=model.named_parameters(),
                                         compression=compression,
                                         op=hvd.Adasum if args.use_adasum else hvd.Average)
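    # The wrapped optimizer allreduces gradients during backward(), so each
    # optimizer.step() applies gradients averaged (or Adasum-combined) across ranks.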
  
  
  
  
  
    """resume"""
    if args.resume is not None:
        model.load_state_dict(torch.load(args.resume))
    
    if os.path.exists(args.save_path + 'resume_step.pkl'):
        resume = load_obj(args.save_path + 'resume_step.pkl')
    else:
        resume = {'resume_step': args.resume_step}
    args.resume_step = resume['resume_step']
    print('resume_step', args.resume_step)

    if args.tb is None:
        tb = SummaryWriter('./tensorboard/'+args.tag+'tb_train_log_fov:{}_delta:{}_depth:{}.pth'
                       .format(list(args.input_size)[0], list(args.delta)[0], args.depth))
    else:
        tb = SummaryWriter(args.tb)
        
    """data_load"""

    
    


    train_dataset = BatchCreator(args.train_data_dir, args.input_size, delta=args.delta, train=True)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe
    if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
            mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
        kwargs['multiprocessing_context'] = 'forkserver'

    # Horovod: use DistributedSampler to partition the training data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset, sampler=train_sampler, **kwargs)

    batch_it = get_batch(train_loader, args.batch_size, args.input_size,
                                     partial(fixed_offsets, fov_moves=train_dataset.shifts))

    """
    
    for index in range(files_total):
        input_h5data_dict = [(abs_path_training_data + sorted_files_train_data)]
        print(input_h5data_dict)
        train_dataset_dict = BatchCreator(input_h5data_dict, args.input_size, delta=args.delta, train=True)
        train_sampler_dict = torch.utils.data.distributed.DistributedSampler(train_dataset_dict, num_replicas=world_size, rank=rank, shuffle=True)
        train_loader_dict = DataLoader(train_dataset_dict, num_workers=0, sampler=train_sampler_dict , pin_memory=True)
        batch_it_dict = get_batch(train_loader_dict, args.batch_size, args.input_size,
                               partial(fixed_offsets, fov_moves=train_dataset_dict.shifts))
    """

    """optimizer"""
    """
    if args.opt == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        optimizer = optim.SGD(model.parameters(), lr=1e-3)
    """
    # optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.step, gamma=args.gamma, last_epoch=-1)
    
    """train_loop"""
    t_last = time.time()
    cnt = 0
    tp = fp = tn = fn = 0
    best_loss = np.inf
    
    model.train()

    while cnt < args.iter:
        cnt += 1
        
        # resume_tb
        if cnt % 1000 == 0:
            resume['resume_step'] = cnt + args.resume_step
            pickle_obj(resume, 'resume_step', args.save_path)
            
        """
        index_batch = (cnt % train_num)
        train_sampler_dict[index_batch].set_epoch(cnt)
        seeds, images, labels, offsets = next(batch_it_dict[index_batch])
        print(input_h5data_dict[index_batch])
        """
        

        # set_epoch reseeds the DistributedSampler's shuffle so each rank draws
        # a fresh shard ordering every iteration.
        train_sampler.set_epoch(cnt)
        seeds, images, labels, offsets = next(batch_it)

        
        
        
        # train
        t_curr = time.time()
        labels = labels.cuda()
        torch_seed = torch.from_numpy(seeds)
        input_data = torch.cat([images, torch_seed], dim=1)
        input_data = input_data.cuda()

        logits = model(input_data)
        updated = torch_seed.cuda() + logits

        optimizer.zero_grad()
        loss = F.binary_cross_entropy_with_logits(updated, labels)
        loss.backward()

        # torch.nn.utils.clip_grad_value_(model.parameters(), args.clip_grad_thr)
        optimizer.step()

        seeds[...] = updated.detach().cpu().numpy()
        
        
        

          
        pred_mask = (updated >= logit(0.8)).detach().cpu().numpy()
        true_mask = (labels > 0.5).cpu().numpy()
        true_bg = np.logical_not(true_mask)
        pred_bg = np.logical_not(pred_mask)
        tp += (true_mask & pred_mask).sum()
        fp += (true_bg & pred_mask).sum()
        fn += (true_mask & pred_bg).sum()
        tn += (true_bg & pred_bg).sum()
        precision = 1.0 * tp / max(tp + fp, 1)
        recall = 1.0 * tp / max(tp + fn, 1)
        accuracy = 1.0 * (tp + tn) / (tp + tn + fp + fn)
        print('[rank_{}:, Iter_{}:, loss: {:.4}, Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%]\r'.format(
            hvd.rank(), cnt, loss.item(), precision * 100, recall * 100, accuracy * 100))

        # scheduler.step()

        """model_saving_(iter)"""
        
        if (cnt % args.save_interval) == 0 and hvd.rank() == 0:
            tp = fp = tn = fn = 0
            # t_last = t_curr
            # best_loss = loss.item()
            input_size_r = list(args.input_size)
            delta_r = list(args.delta)
            torch.save(model.state_dict(), os.path.join(
                args.save_path,
                str(args.tag) + 'ffn_model_fov:{}_delta:{}_depth:{}.pth'.format(
                    input_size_r[0], delta_r[0], args.depth)))
            torch.save(model.state_dict(), os.path.join(
                args.save_path,
                str(args.tag) + 'ffn_model_fov:{}_delta:{}_depth:{}_recall{}_.pth'.format(
                    input_size_r[0], delta_r[0], args.depth, recall * 100)))

            print('Precision: {:.2f}%, Recall: {:.2f}%, Accuracy: {:.2f}%, Model saved!'.format(
                precision * 100, recall * 100, accuracy * 100))

            buffer_step = 3000
            resume_step = args.resume_step - buffer_step
            if cnt > buffer_step:
                tb.add_scalar("Loss", loss.item(), cnt + resume_step)
                tb.add_scalar("Precision", precision * 100, cnt + resume_step)
                tb.add_scalar("Recall", recall * 100, cnt + resume_step)
                tb.add_scalar("Accuracy", accuracy * 100, cnt + resume_step)