    def __init__(self, args):
        self.args = args

        # define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args)

        # define network
        model = DeepLab(num_classes=self.nclass, output_stride=args.out_stride)

        optimizer = torch.optim.Adam(model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)

        self.criterion = SegmentationLoss(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        self.evaluator = Evaluator(self.nclass)
        self.best_pred = 0.0

        self.trainloss_history = []
        self.valloss_history = []

        self.train_plot = []
        self.val_plot = []
        # LR schedule: with a 'step' policy the learning rate decays every 20 epochs (lr_step=20)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.base_lr, args.epochs, len(self.train_loader), lr_step=20)

        if args.cuda:
            self.model = self.model.cuda()
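
This Trainer is typically driven by a per-epoch training method. A minimal sketch, assuming the LR_Scheduler callable interface from pytorch-deeplab-xception (invoked once per iteration, not per epoch) and batches keyed by 'image'/'label' (both assumptions):

    def training(self, epoch):
        self.model.train()
        for i, sample in enumerate(self.train_loader):
            image, target = sample['image'], sample['label']  # assumed batch keys
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)  # per-iteration lr update
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            self.trainloss_history.append(loss.item())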
def testNetwork(images_folder, labels_folder, dictionary, target_classes, network_filename, output_folder):
    """
    Load a network and test it on the test dataset.
    :param network_filename: Full name of the network to load (PATH+name)
    """

    # TEST DATASET
    datasetTest = CoralsDataset(images_folder, labels_folder, dictionary, target_classes)
    datasetTest.disableAugumentation()

    classifier_info_filename = network_filename.replace(".net", ".json")
    output_classes = readClassifierInfo(classifier_info_filename, datasetTest)

    batchSize = 4
    dataloaderTest = DataLoader(datasetTest, batch_size=batchSize, shuffle=False, num_workers=0, drop_last=True,
                            pin_memory=True)

    # DEEPLAB V3+
    net = DeepLab(backbone='resnet', output_stride=16, num_classes=output_classes)
    net.load_state_dict(torch.load(network_filename))
    print("Weights loaded.")

    print("Test..")
    metrics_test, loss = evaluateNetwork(datasetTest, dataloaderTest, "NONE", None, [0.0], 0.0, 0.0, 0.0, 0, 0, 0,
                                         output_classes, net, True, output_folder)
    metrics_filename = network_filename[:len(network_filename) - 4] + "-test-metrics.txt"
    saveMetrics(metrics_test, metrics_filename)
    print("***** TEST FINISHED *****")
Example #3
def main():
    model = DeepLab(output_stride=16, num_classes=21)
    train_loader = TrainDataset(data_root='/home/guomenghao/voc_aug/mydata/', split='train', batch_size=4, shuffle=True)
    val_loader = ValDataset(data_root='/home/guomenghao/voc_aug/mydata/', split='val', batch_size=1, shuffle=False)
    learning_rate = 0.005
    momentum = 0.9
    weight_decay = 1e-4
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                momentum=momentum, weight_decay=weight_decay)
    writer = SummaryWriter(os.path.join('curve', 'train.events.wo_drop'))
    epochs = 50
    evaluator = Evaluator(21)
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch, learning_rate, writer)
        val(model, val_loader, epoch, evaluator, writer)
Example #4
def main():
    #Place = paddle.fluid.CPUPlace()
    Place = paddle.fluid.CUDAPlace(0)
    with fluid.dygraph.guard(Place):
        transform = Transform(256)
        dataload = Dataloader(args.image_folder, args.image_list_file,
                              transform, True)
        train_load = fluid.io.DataLoader.from_generator(capacity=1,
                                                        use_multiprocess=False)
        train_load.set_sample_generator(dataload,
                                        batch_size=args.batch_size,
                                        places=Place)
        total_batch = int(len(dataload) / args.batch_size)

        if args.net == 'deeplab':
            model = DeepLab(59)
        else:
            raise NotImplementedError("Other models haven't been implemented yet.")

        costFunc = SegLoss
        adam = AdamOptimizer(learning_rate=args.lr,
                             parameter_list=model.parameters())

        for epoch in range(1, args.num_epochs + 1):
            train_loss = train(train_load, model, costFunc, adam, epoch,
                               total_batch)
            print(
                f"----- Epoch[{epoch}/{args.num_epochs}] Train Loss: {train_loss}"
            )

            if epoch % args.save_freq == 0 or epoch == args.num_epochs:
                model_path = os.path.join(
                    args.checkpoint_folder,
                    f"{args.net}-Epoch-{epoch}-Loss-{train_loss}")

                model_dict = model.state_dict()
                fluid.save_dygraph(model_dict, model_path)
                optimizer_dict = adam.state_dict()  # the optimizer above is named 'adam'
                fluid.save_dygraph(optimizer_dict, model_path)
                print(f'----- Save model: {model_path}.pdparams')
                print(f'----- Save optimizer: {model_path}.pdopt')
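
fluid.save_dygraph picks the file suffix from the dict it receives, writing model_path + '.pdparams' for parameters and model_path + '.pdopt' for optimizer state, so saving both under the same prefix is intentional. A minimal resume sketch, with method names assumed from the Paddle 1.x dygraph API used above:

    param_dict, opt_dict = fluid.load_dygraph(model_path)  # reads .pdparams / .pdopt
    model.set_dict(param_dict)   # restore the network weights
    adam.set_dict(opt_dict)      # restore the optimizer state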
Example #5
def main():
    logging.basicConfig(format='[%(asctime)s]%(levelname)s:%(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO)

    root_path = os.path.dirname(sys.argv[0])

    try:
        storage = RocksdbStorage(os.path.join(root_path, DATABASE_NAME))
    except TypeError:
        storage = RocksdbStorageV2(os.path.join(root_path, DATABASE_NAME))
    deeplab_inst = DeepLab(range(1), root_path, IMG_MEAN, storage)
    fst_inst = Fst(range(1), root_path, IMG_SHAPE, storage)
    server = Server(deeplab_inst, fst_inst, root_path, storage)

    try:
        server.run()
    except KeyboardInterrupt:
        pass
    finally:
        deeplab_inst.join_all()
        fst_inst.join_all()
        deeplab_inst.shutdown()
        fst_inst.shutdown()
Example #6
                                    num_workers=2),
        'seresnext':
        torch.utils.data.DataLoader(image_datasets['seresnext'],
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=2)
    }

    use_gpu = torch.cuda.is_available()
    ips = ['ip_2', 'ip_3', 'ip_4', 'ip_5', 'ip_7']
    models = {'drn': [], 'seresnext': []}
    for i in range(5):
        weight_path = os.path.join(
            current, 'result/resize/' + ips[i] + '/' + str(i + 1))
        weight_name = get_weight_name(weight_path)
        model = DeepLab(num_classes=5, backbone='drn54').eval()
        #model.load_state_dict(torch.load(os.path.join(current,"pretrained_models/model_13_2_2_2_epoch_580.pth")))
        #model.aspp.conv_1x1_4 = nn.Conv2d(256, 5, kernel_size=1)
        """
        for idx,p in enumerate(model.parameters()):
            if idx!=0:
                p.requires_grad = False
        """
        if use_gpu:
            #torch.distributed.init_process_group(backend="nccl")
            model = nn.DataParallel(model).to(device)
            #model = model.cuda()
            #print(model.module)

        model.load_state_dict(torch.load(weight_name))
        models['drn'].append(model)
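
Note that each model is wrapped in nn.DataParallel before load_state_dict, so the checkpoint keys are expected to carry the 'module.' prefix. If a checkpoint was saved from an unwrapped model, a common fix is to remap the keys; a sketch, not from the source:

    state = torch.load(weight_name)
    # drop a leading 'module.' from each key so it fits the unwrapped model
    state = {k[len('module.'):] if k.startswith('module.') else k: v
             for k, v in state.items()}
    model.module.load_state_dict(state)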
Example #7
def train_net(args):
    checkpoint = args.checkpoint
    start_epoch = 1
    best_loss = float('inf')
    writer = SummaryWriter(logdir=args.logdir)
    epochs_since_improvement = 0
    decays_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        torch.random.manual_seed(7)
        torch.cuda.manual_seed(7)
        np.random.seed(7)
        model = DIMModel(num_classes=1)
        if args.pretrained:
            migrate(model)
        model = nn.DataParallel(model)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=[args.beta1, args.beta2])
        start_epoch = args.start_epoch
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        if 'torch_seed' in checkpoint:
            torch.random.set_rng_state(checkpoint['torch_seed'])
        else:
            torch.random.manual_seed(7)
        if 'torch_cuda_seed' in checkpoint:
            torch.cuda.set_rng_state(checkpoint['torch_cuda_seed'])
        else:
            torch.cuda.manual_seed(7)
        if 'np_seed' in checkpoint:
            np.random.set_state(checkpoint['np_seed'])
        else:
            np.random.seed(7)
        if 'python_seed' in checkpoint:
            random.setstate(checkpoint['python_seed'])
        else:
            random.seed(7)

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    train_dataset = DIMDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True)
    valid_dataset = DIMDataset('valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        if args.optimizer == 'sgd' and epochs_since_improvement == 10:
            break

        if args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            decays_since_improvement += 1
            print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))
            adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)

        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger)
        effective_lr = get_learning_rate(optimizer)
        print('Current effective learning rate: {}\n'.format(effective_lr))

        writer.add_scalar('Train_Loss', train_loss, epoch)
        writer.add_scalar('Learning_Rate', effective_lr, epoch)

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           epoch=epoch,
                           logger=logger)

        writer.add_scalar('Valid_Loss', valid_loss, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0
            decays_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best, args.checkpointdir)
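
The resume branch above reads specific keys back ('epoch', 'epochs_since_improvement', 'model', 'optimizer', plus the four RNG states), so save_checkpoint must write the matching dict. A minimal sketch consistent with those keys (the project's actual helper may differ):

    def save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                        best_loss, is_best, checkpointdir):
        # persist everything train_net restores when resuming, including RNG states
        state = {'epoch': epoch,
                 'epochs_since_improvement': epochs_since_improvement,
                 'model': model,
                 'optimizer': optimizer,
                 'best_loss': best_loss,
                 'torch_seed': torch.random.get_rng_state(),
                 'torch_cuda_seed': torch.cuda.get_rng_state(),
                 'np_seed': np.random.get_state(),
                 'python_seed': random.getstate()}
        torch.save(state, os.path.join(checkpointdir, 'checkpoint.tar'))
        if is_best:
            # keep a separate copy so later epochs don't overwrite the best model
            torch.save(state, os.path.join(checkpointdir, 'BEST_checkpoint.tar'))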
Example #8
    def __init__(self, args):
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)

        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
        # Define Criterion: optionally use class-balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                ROOT_PATH, args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            # note: the weight tensor is kept on the CPU here
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
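
The two parameter groups built from get_1x_lr_params and get_10x_lr_params reflect the usual DeepLab fine-tuning recipe: the pretrained backbone keeps a small learning rate while the freshly initialized ASPP/decoder head gets ten times that rate. A simplified sketch of what such generators typically yield, as methods on the DeepLab model (an assumption about its internals, not the repo's exact code):

    def get_1x_lr_params(self):
        # pretrained backbone: a small lr preserves the ImageNet features
        for p in self.backbone.parameters():
            if p.requires_grad:
                yield p

    def get_10x_lr_params(self):
        # freshly initialized head: a larger lr lets it converge quickly
        for module in (self.aspp, self.decoder):
            for p in module.parameters():
                if p.requires_grad:
                    yield p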
Example #9
        for x in ["train", "validation"]
    }

    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=2,
            #shuffle=x=="train",
            sampler=data_sampler[x],
            num_workers=2)
        for x in ['train', 'validation']
    }

    use_gpu = torch.cuda.is_available()

    model = DeepLab()
    #model.load_state_dict(torch.load(os.path.join(current,"pretrained_models/model_13_2_2_2_epoch_580.pth")))
    #model.aspp.conv_1x1_4 = nn.Conv2d(256, 20, kernel_size=1)
    """
    for idx,p in enumerate(model.parameters()):
        if idx!=0:
            p.requires_grad = False
    """
    if use_gpu:
        #torch.distributed.init_process_group(backend="nccl")
        model = nn.DataParallel(model).to(device)
        #model = model.cuda()
        print(model.module)

    #optimizer = optim.SGD(list(model.module.conv1.parameters())+list(model.module.fc.parameters()), lr=0.001)#, momentum=0.9)
Example #10
                        shuffle=True,
                        num_workers=0)

voc_val = VOCSegmentation(args, base_dir=base_dir, mode='test')
val_dataloader = DataLoader(voc_val,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=0)
"""
3、定义模型

"""

model = DeepLab(num_classes=18,
                backbone="mobilenet",
                output_stride=8,
                sync_bn=False,
                freeze_bn=False)
"""
4、定义优化器

"""

train_params = model.parameters()
optimizer = torch.optim.SGD(train_params,
                            lr=0.003,
                            momentum=0.9,
                            weight_decay=5e-4,
                            nesterov=False)

scheduler = lr_scheduler.StepLR(optimizer, 100, gamma=0.5, last_epoch=-1)
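
Unlike the per-iteration LR_Scheduler in Example #8, StepLR is stepped once per epoch and halves the learning rate every 100 epochs here. A minimal loop sketch (train_one_epoch, validate, and num_epochs are hypothetical placeholders):

    for epoch in range(num_epochs):
        train_one_epoch(model, dataloader, optimizer)  # hypothetical helper
        validate(model, val_dataloader)                # hypothetical helper
        scheduler.step()  # lr *= 0.5 at epochs 100, 200, ...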
Example #11
def trainingNetwork(images_folder_train, labels_folder_train,
                    images_folder_val, labels_folder_val, dictionary,
                    target_classes, num_classes, save_network_as,
                    save_classifier_as, classifier_name, epochs, batch_sz,
                    batch_mult, learning_rate, L2_penalty,
                    validation_frequency, loss_to_use, epochs_switch,
                    epochs_transition, tversky_alpha, tversky_gamma, optimiz,
                    flagShuffle, experiment_name):

    ##### DATA #####

    # setup the training dataset
    datasetTrain = CoralsDataset(images_folder_train, labels_folder_train,
                                 dictionary, target_classes, num_classes)

    print("Dataset setup..", end='')
    datasetTrain.computeAverage()
    datasetTrain.computeWeights()
    print(datasetTrain.dict_target)
    print(datasetTrain.weights)
    freq = 1.0 / datasetTrain.weights
    print(freq)
    print("done.")

    writeClassifierInfo(save_classifier_as, classifier_name, datasetTrain)

    datasetTrain.enableAugumentation()

    datasetVal = CoralsDataset(images_folder_val, labels_folder_val,
                               dictionary, target_classes, num_classes)
    datasetVal.dataset_average = datasetTrain.dataset_average
    datasetVal.weights = datasetTrain.weights

    # augmentation is not applied to the validation set
    datasetVal.disableAugumentation()

    # setup the data loader
    dataloaderTrain = DataLoader(datasetTrain,
                                 batch_size=batch_sz,
                                 shuffle=flagShuffle,
                                 num_workers=0,
                                 drop_last=True,
                                 pin_memory=True)

    validation_batch_size = 4
    dataloaderVal = DataLoader(datasetVal,
                               batch_size=validation_batch_size,
                               shuffle=False,
                               num_workers=0,
                               drop_last=True,
                               pin_memory=True)

    training_images_number = len(datasetTrain.images_names)
    validation_images_number = len(datasetVal.images_names)

    print("NETWORK USED: DEEPLAB V3+")

    if os.path.exists(save_network_as):
        net = DeepLab(backbone='resnet',
                      output_stride=16,
                      num_classes=datasetTrain.num_classes)
        net.load_state_dict(torch.load(save_network_as))
        print("Checkpoint loaded.")
    else:
        ###### SETUP THE NETWORK #####
        net = DeepLab(backbone='resnet',
                      output_stride=16,
                      num_classes=datasetTrain.num_classes)
        state = torch.load("deeplab-resnet.pth.tar")
        # re-initialize the classification layer with the right number of classes:
        # don't load the pretrained weights of the classification layer
        new_dictionary = state['state_dict']
        del new_dictionary['decoder.last_conv.8.weight']
        del new_dictionary['decoder.last_conv.8.bias']
        net.load_state_dict(state['state_dict'], strict=False)

    # OPTIMIZER
    if optimiz == "SGD":
        optimizer = optim.SGD(net.parameters(),
                              lr=learning_rate,
                              weight_decay=L2_penalty,
                              momentum=0.9)
    elif optimiz == "ADAM":
        optimizer = optim.Adam(net.parameters(),
                               lr=learning_rate,
                               weight_decay=L2_penalty)
    elif optimiz == "QHADAM":
        pass

    USE_CUDA = torch.cuda.is_available()
    # define device unconditionally: it is also used below for the loss weights
    device = torch.device("cuda" if USE_CUDA else "cpu")
    net.to(device)

    ##### TRAINING LOOP #####

    # Writer will output to ./runs/ directory by default
    writer = SummaryWriter(comment=experiment_name)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=2,
                                                     verbose=True)

    best_accuracy = 0.0
    best_jaccard_score = 0.0

    # Crossentropy loss
    weights = datasetTrain.weights
    class_weights = torch.FloatTensor(weights).to(device)
    CEloss = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)

    # weights for GENERALIZED DICE LOSS (GDL)
    freq = 1.0 / datasetTrain.weights[1:]
    w = 1.0 / (freq * freq)
    w = w / w.sum() + 0.00001
    w_for_GDL = torch.from_numpy(w)
    w_for_GDL = w_for_GDL.to(device)

    # Focal Tversky loss
    focal_tversky_gamma = torch.tensor(tversky_gamma)
    focal_tversky_gamma = focal_tversky_gamma.to(device)

    tversky_loss_alpha = torch.tensor(tversky_alpha)
    tversky_loss_beta = torch.tensor(1.0 - tversky_alpha)
    tversky_loss_alpha = tversky_loss_alpha.to(device)
    tversky_loss_beta = tversky_loss_beta.to(device)

    print("Training Network")
    for epoch in range(epochs):  # loop over the dataset multiple times

        net.train()
        optimizer.zero_grad()

        writer.add_scalar('LR/train', optimizer.param_groups[0]['lr'], epoch)

        running_loss = 0.0
        for i, minibatch in enumerate(dataloaderTrain):
            # get the inputs
            images_batch = minibatch['image']
            labels_batch = minibatch['labels']

            if USE_CUDA:
                images_batch = images_batch.to(device)
                labels_batch = labels_batch.to(device)

            # forward+loss+backward
            outputs = net(images_batch)

            loss = computeLoss(loss_to_use, CEloss, w_for_GDL,
                               tversky_loss_alpha, tversky_loss_beta,
                               focal_tversky_gamma, epoch, epochs_switch,
                               epochs_transition, labels_batch, outputs)

            loss.backward()

            # gradient accumulation: step the optimizer every batch_mult mini-batches
            # (effective batch size = batch_sz * batch_mult) to limit memory use
            if (i + 1) % batch_mult == 0:
                optimizer.step()
                optimizer.zero_grad()

            print(epoch, i, loss.item())
            running_loss += loss.item()

        print("Epoch: %d , Running loss = %f" % (epoch, running_loss))

        ### VALIDATION ###
        if epoch > 0 and (epoch + 1) % validation_frequency == 0:

            print("RUNNING VALIDATION.. ", end='')

            metrics_val, mean_loss_val = evaluateNetwork(
                datasetVal,
                dataloaderVal,
                loss_to_use,
                CEloss,
                w_for_GDL,
                tversky_loss_alpha,
                tversky_loss_beta,
                focal_tversky_gamma,
                epoch,
                epochs_switch,
                epochs_transition,
                datasetVal.num_classes,
                net,
                flag_compute_mIoU=False)
            accuracy = metrics_val['Accuracy']
            jaccard_score = metrics_val['JaccardScore']

            scheduler.step(mean_loss_val)

            metrics_train, mean_loss_train = evaluateNetwork(
                datasetTrain,
                dataloaderTrain,
                loss_to_use,
                CEloss,
                w_for_GDL,
                tversky_loss_alpha,
                tversky_loss_beta,
                focal_tversky_gamma,
                epoch,
                epochs_switch,
                epochs_transition,
                datasetTrain.num_classes,
                net,
                flag_compute_mIoU=False)
            accuracy_training = metrics_train['Accuracy']
            jaccard_training = metrics_train['JaccardScore']

            writer.add_scalar('Loss/train', mean_loss_train, epoch)
            writer.add_scalar('Loss/validation', mean_loss_val, epoch)
            writer.add_scalar('Accuracy/train', accuracy_training, epoch)
            writer.add_scalar('Accuracy/validation', accuracy, epoch)

            #if jaccard_score > best_jaccard_score:
            if accuracy > best_accuracy:

                best_accuracy = accuracy
                best_jaccard_score = jaccard_score
                torch.save(net.state_dict(), save_network_as)
                # performance of the best accuracy network on the validation dataset
                metrics_filename = save_network_as[:len(save_network_as) -
                                                   4] + "-val-metrics.txt"
                saveMetrics(metrics_val, metrics_filename)
                metrics_filename = save_network_as[:len(save_network_as) -
                                                   4] + "-train-metrics.txt"
                saveMetrics(metrics_train, metrics_filename)

            print("-> CURRENT BEST ACCURACY ", best_accuracy)

    # main loop ended - reload it and evaluate mIoU
    torch.cuda.empty_cache()
    del net
    net = None

    print("Final evaluation..")
    net = DeepLab(backbone='resnet',
                  output_stride=16,
                  num_classes=datasetTrain.num_classes)
    net.load_state_dict(torch.load(save_network_as))

    metrics_val, mean_loss_val = evaluateNetwork(datasetVal,
                                                 dataloaderVal,
                                                 loss_to_use,
                                                 CEloss,
                                                 w_for_GDL,
                                                 tversky_loss_alpha,
                                                 tversky_loss_beta,
                                                 focal_tversky_gamma,
                                                 epoch,
                                                 epochs_switch,
                                                 epochs_transition,
                                                 datasetVal.num_classes,
                                                 net,
                                                 flag_compute_mIoU=True)

    writer.add_hparams(
        {
            'LR': learning_rate,
            'Decay': L2_penalty,
            'Loss': loss_to_use,
            'Transition': epochs_transition,
            'Gamma': tversky_gamma,
            'Alpha': tversky_alpha
        }, {
            'hparam/Accuracy': best_accuracy,
            'hparam/mIoU': best_jaccard_score
        })

    writer.close()

    print("***** TRAINING FINISHED *****")
    print("BEST ACCURACY REACHED ON THE VALIDATION SET: %.3f " % best_accuracy)
Example #12
args.base_size = 513
args.crop_size = 513

voc_train = VOCSegmentation(args, base_dir=base_dir, mode='train')
dataloader = DataLoader(voc_train, batch_size=2, shuffle=True, num_workers=0)

voc_val = VOCSegmentation(args, base_dir=base_dir, mode='test')
val_dataloader = DataLoader(voc_val, batch_size=2, shuffle=True, num_workers=0)
"""
3、定义模型

"""
#model = Unet(in_channels=3,n_classes=18)
model = DeepLab(num_classes=18,
                backbone="resnet",
                output_stride=8,
                sync_bn=False,
                freeze_bn=False)
"""
8、使用预定义模型,并判断是否使用cuda
"""

if os.path.isdir(model_path):
    try:
        checkpoint = torch.load(model_path + 'model_deeplab.t7',
                                map_location='cpu')
        model.load_state_dict(checkpoint['state'])
        start_epoch = checkpoint['epoch']
        print('===> Load last checkpoint data')
    except FileNotFoundError:
        print("Can't find model_deeplab.t7")
Example #13
    #                 'validation':dataset_sizes['validation']}

    sample_weight = {x: [1] * dataset_sizes[x] for x in ['train', 'validation']}
    data_sampler = {
        x: torch.utils.data.sampler.WeightedRandomSampler(sample_weight[x],
                                                          num_samples=num_samples[x],
                                                          replacement=True)
        for x in ["train", "validation"]
    }

    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=batch_size,
                                                  #shuffle=x=="train",
                                                  sampler=data_sampler[x],
                                                  # worker_init_fn must be a callable; calling
                                                  # np.random.seed(seed) inline would pass None
                                                  worker_init_fn=lambda worker_id: np.random.seed(seed + worker_id),
                                                  num_workers=2) for x in ['train', 'validation']}

    use_gpu = torch.cuda.is_available()


    model = DeepLab(num_classes=5, backbone='seresnext50')
    #model.load_state_dict(torch.load(os.path.join(current,"pretrained_models/model_13_2_2_2_epoch_580.pth")))
    #model.aspp.conv_1x1_4 = nn.Conv2d(256, 5, kernel_size=1)
    """
    for idx,p in enumerate(model.parameters()):
        if idx!=0:
            p.requires_grad = False
    """
    if use_gpu:
        #torch.distributed.init_process_group(backend="nccl")
        model = nn.DataParallel(model).to(device)
        #model = model.cuda()
        #print(model.module)
    if parsed.weight_name:
        model.load_state_dict(torch.load(os.path.join(current,parsed.weight_name)))
    #optimizer = optim.SGD(list(model.module.conv1.parameters())+list(model.module.fc.parameters()), lr=0.001)#, momentum=0.9)