Example #1
def main(args):
    dataset = get_data_loader(
        args.files_pattern,
        batch_size=args.batch_size,
        num_workers=args.num_data_workers,
        crop_size=args.crop_size if args.crop_size != 0 else None)

    device = torch.cuda.current_device()
    rank = dist.get_rank() if dist.is_initialized() else 0
    if args.dataloader_only:
        if rank == 0:
            print("Running in dataloader_only mode ...")
    else:
        model = UNet().to(device)
        if dist.is_initialized():
            model = DistributedDataParallel(model,
                                            device_ids=[args.local_rank],
                                            output_device=args.local_rank)

        if args.forward_only:
            if rank == 0:
                print("Running in inference (forward only) mode ...")
        else:
            if rank == 0:
                print("Running in training (forward/backward) mode ...")
            optimizer = torch.optim.Adam(model.parameters())

    total_time = 0
    for epoch in range(args.epochs + 1):
        if rank == 0:
            print("epoch", epoch)
        t_start = time.time()
        for idx, data in enumerate(dataset):
            if idx > args.max_batches_per_epoch:
                break

            if args.dataloader_only:
                continue

            inp, tar = map(lambda x: x.to(device), data)

            if args.forward_only:
                model.eval()
                gen = model(inp)
            else:
                model.zero_grad()
                model.train()
                gen = model(inp)
                loss = torch.nn.functional.l1_loss(gen, tar)
                loss.backward()
                optimizer.step()

        if epoch > 0:
            total_time += time.time() - t_start

    n_batches = min(args.max_batches_per_epoch + 1, len(dataset))
    if rank == 0:
        print("Timing:",
              float(args.batch_size * n_batches * args.epochs) / (total_time),
              "samples/s")
Example #2
def main(config):

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    train_dataset, valid_dataset = generate_datasets(
        config['data_dir'],
        valid_ids=config['val_ids'],
        load_in_memory=config['load_in_memory'])

    print(f'Length of training dataset: {len(train_dataset)}')
    print(f'Length of validation dataset: {len(valid_dataset)}')

    # TODO: define and add data augmentation + image normalization
    # train_dataset.transform = train_transform
    # valid_dataset.transform = valid_transform
    transforms = A.Compose([
        A.Normalize(),  # TODO: change values
        ToTensorV2()
    ])
    train_dataset.transform = transforms
    valid_dataset.transform = transforms

    train_loader = DataLoader(train_dataset,
                              batch_size=config['batch_size'],
                              shuffle=True,
                              num_workers=config['num_workers'])
    valid_loader = DataLoader(valid_dataset,
                              batch_size=config['batch_size'],
                              shuffle=False,
                              num_workers=config['num_workers'])
    model = UNet()
    model = model.to(device)

    criterion = config['criterion']
    optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=10,
                                                gamma=0.5)
    metrics = [iou_score, dice_score]

    trainer = Trainer(model=model,
                      criterion=criterion,
                      metrics=metrics,
                      optimizer=optimizer,
                      lr_scheduler=scheduler,
                      config=config,
                      train_loader=train_loader,
                      val_loader=valid_loader,
                      device=device)
    trainer.train()
    return model
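
Note: iou_score and dice_score are not shown in this example. A minimal, hedged sketch of what they could look like, assuming binary masks and a (logits, target) signature:

import torch

def iou_score(logits, target, eps=1e-7):
    # assumes logits and target share shape (N, 1, H, W)
    pred = (torch.sigmoid(logits) > 0.5).float()
    inter = (pred * target).sum(dim=(1, 2, 3))
    union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) - inter
    return ((inter + eps) / (union + eps)).mean()

def dice_score(logits, target, eps=1e-7):
    pred = (torch.sigmoid(logits) > 0.5).float()
    inter = (pred * target).sum(dim=(1, 2, 3))
    total = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    return ((2 * inter + eps) / (total + eps)).mean()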
Example #3
def choose_model(model_name, phase):
    if model_name == 'ADNet':
        return ADNET(3, phase)
    elif model_name == 'DnCNN':
        return DNCNN(3, phase)
    elif model_name == 'BRDNet':
        return BRDNET(3, phase)
    elif model_name == 'feb_rfb_ab_mish_a_add':
        return feb_rfb_ab_mish_a_add(3, phase)
    elif 'SXNet' in model_name:
        info = model_name.split('_')
        basic_ch, rfb_ch, asy, tlu, bias = int(info[1]), int(
            info[2]), False, False, False
        if '_a' in model_name:
            asy = True
        if '_t' in model_name:
            tlu = True
        if '_b' in model_name:
            bias = True
        return SXNet(basic_ch, rfb_ch, 3, phase, asy=asy, bias=bias, tlu=tlu)
    elif 'MKM' == model_name:
        return DNCNN_based(3, 'mkm', phase)
    elif 'RM' == model_name:
        return DNCNN_based(3, 'rm', phase)
    elif 'MKM_RM' == model_name:
        return DNCNN_based(3, 'mkm', phase)
    elif 'Vanilla' == model_name:
        return DNCNN_based(3, 'vanilla', phase)
    elif model_name == 'UNet':
        return UNet(3, phase)
    else:
        raise ValueError('Unknown model name: {}'.format(model_name))
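
Note: a quick usage sketch of the naming convention the SXNet branch parses (the phase value 'train' is an assumption):

# 'SXNet_16_32_a_t' -> basic_ch=16, rfb_ch=32, asy=True, tlu=True, bias=False
model = choose_model('SXNet_16_32_a_t', phase='train')
denoiser = choose_model('DnCNN', phase='test')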
Example #4
    def _create_model(self, model_type):
        if model_type == 'ar_lstm':
            model = AR_LSTM(self.args.num_input_frames,
                            self.args.num_output_frames,
                            self.args.reinsert_frequency, self.device)
        elif model_type == 'convlstm':
            model = get_convlstm_model(self.args.num_input_frames,
                                       self.args.num_output_frames,
                                       self.args.batch_size,
                                       self.device,
                                       dilation=1,
                                       padding=1)
        elif model_type == 'dilated_convlstm':
            model = get_convlstm_model(self.args.num_input_frames,
                                       self.args.num_output_frames,
                                       self.args.batch_size,
                                       self.device,
                                       dilation=2,
                                       padding=2)
        elif model_type == 'predrnn':
            model = PredRNNPP(self.args.num_input_frames,
                              self.args.num_output_frames,
                              self.device,
                              use_GHU=False)
        elif model_type == 'predrnn_ghu':
            model = PredRNNPP(self.args.num_input_frames,
                              self.args.num_output_frames,
                              self.device,
                              use_GHU=True)
        elif model_type == 'resnet':
            model = resnet12(self.args.num_input_frames,
                             self.args.num_output_frames)
        elif model_type == 'resnet_dilated':
            model = resnet12(self.args.num_input_frames,
                             self.args.num_output_frames,
                             replace_stride_with_dilation=[1, 2, 4])
        elif model_type == 'unet':
            model = UNet(self.args.num_input_frames,
                         self.args.num_output_frames,
                         isize=64)
        elif model_type == 'unet_small':
            model = UNet(self.args.num_input_frames,
                         self.args.num_output_frames,
                         isize=16)
        else:
            raise ValueError('Unsupported model type: {}'.format(model_type))
        return model
Example #5
def main():
    if args.gpu != -1 and not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))

    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    genotype = eval("genotypes.%s" % args.arch)
    #model = Network(args.init_channels, dataset_classes, args.layers, args.auxiliary, genotype)
    model = Network(args)
    model = model.to(device)
    util.load(model, args.model_path)

    logging.info("param size = %fMB", util.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)

    test_data = MyDataset(args=args, subset='test')

    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=2)

    model.drop_path_prob = args.drop_path_prob
    test_acc, test_obj, test_fscore, test_MIoU = infer(test_queue, model,
                                                       criterion)
    logging.info('test_acc %f _fscores %f_MIoU %f', test_acc, test_fscore,
                 test_MIoU)
Example #6
def main():
    parser = argparse.ArgumentParser(
        description='Visualize segmentations obtained from trained UNet')
    parser.add_argument('checkpoint',
                        type=str,
                        help='Path to UNet model checkpoint')
    parser.add_argument(
        'img_path',
        type=str,
        help='Path to image or directory containing images to segment')
    parser.add_argument('-r',
                        '--resize',
                        type=int,
                        default=0,
                        help='Resize size of the image prior to segmentation')
    parser.add_argument('-o',
                        '--out',
                        type=str,
                        default='./',
                        help='Path to write segmentation visualizations to')
    args = parser.parse_args()

    if not os.path.exists(args.checkpoint):
        sys.exit('Specified checkpoint cannot be found')

    if not os.path.exists(args.img_path):
        sys.exit('Images for segmentation could not be found')

    imgs = []
    if os.path.isdir(args.img_path):
        for file in os.listdir(args.img_path):
            if os.path.isfile(os.path.join(args.img_path, file)):
                imgs.append(os.path.join(args.img_path, file))
    else:
        imgs.append(args.img_path)

    checkpoint = torch.load(args.checkpoint,
                            map_location=lambda storage, loc: storage)
    model = UNet(num_classes=len(datasets.Cityscapes.classes))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    visualizer = CityscapeSegmentationVis(model,
                                          input_image_transform(args.resize))

    for img in imgs:
        image_name = 'segmentation_' + os.path.basename(img)
        out_location = os.path.join(args.out, image_name)
        class_tensor = visualizer.get_predicted_segmentation(img)
        visualizer.save_segmentation(class_tensor, out_location)
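
Note: input_image_transform is not shown. A plausible, hedged sketch using torchvision (the normalization constants are assumed ImageNet statistics):

from torchvision import transforms

def input_image_transform(resize):
    # hypothetical reconstruction: optional resize, then tensor + normalization
    steps = []
    if resize > 0:
        steps.append(transforms.Resize(resize))
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)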
Example #7
    def compile(self):

        if self.compiled:
            print('Model already compiled.')
            return
        self.compiled = True

        # Placeholders.
        self.X = tf.placeholder(tf.float32, shape=(None, 32, 32, 1), name='X')
        self.Y = tf.placeholder(tf.float32, shape=(None, 32, 32, 2), name='Y')

        # U-Net.
        net = UNet(self.seed)
        self.out = net.forward(self.X)

        # Loss and metrics.
        # TODO: try with MAE.
        self.loss = tf.keras.losses.MeanSquaredError()(self.Y, self.out)

        # Global step.
        self.global_step = tf.Variable(0, trainable=False, name='Global_Step')

        # Learning rate.
        if self.learning_rate_decay:
            self.lr = tf.train.exponential_decay(
                self.learning_rate,
                self.global_step,
                self.learning_rate_decay_steps,
                self.learning_rate_decay_rate,
                name='learning_rate_decay')
        else:
            self.lr = tf.constant(self.learning_rate)

        # Optimizer.
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=self.lr).minimize(self.loss,
                                            global_step=self.global_step)

        # Sampler.
        gen_sample = UNet(self.seed, is_training=False)
        self.sampler = gen_sample.forward(self.X, reuse_vars=True)

        # Tensorboard.
        tf.summary.scalar('loss', self.loss)

        self.saver = tf.train.Saver()
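
Note: compile() only builds the graph. A minimal, hedged TF1-style session loop to drive it (Model is the assumed name of the class this method belongs to; the random batches are stand-ins):

import numpy as np
import tensorflow as tf

model = Model()  # assumed class exposing the compile() shown above
model.compile()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(100):
        # random stand-in batch: one input channel in, two output channels out
        x = np.random.rand(16, 32, 32, 1).astype(np.float32)
        y = np.random.rand(16, 32, 32, 2).astype(np.float32)
        _, loss_val = sess.run([model.optimizer, model.loss],
                               feed_dict={model.X: x, model.Y: y})
        if step % 10 == 0:
            print('step {}: loss {:.4f}'.format(step, loss_val))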
Example #8
    if args.optimizer == 'SGD':
        optimizer_factory = optimizer_setup(optim.SGD,
                                            lr=learning_rate,
                                            momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer_factory = optimizer_setup(optim.Adam, lr=learning_rate)

    if args.model == 'CnnVanilla':
        model = CnnVanilla(num_classes=10)
    elif args.model == 'AlexNet':
        model = AlexNet(num_classes=10)
    elif args.model == 'VggNet':
        model = VggNet(num_classes=10)
    elif args.model == 'ResNet':
        model = ResNet(num_classes=10)
    elif args.model == 'IFT725Net':
        model = IFT725Net(num_classes=10)
    elif args.model == 'UNet':
        model = UNet(num_classes=4)
        args.dataset = 'acdc'

        if data_augment:
            train_set = HDF5Dataset('train',
                                    hdf5_file,
                                    transform=acdc_augment_transform)

        else:
            train_set = HDF5Dataset('train',
                                    hdf5_file,
                                    transform=acdc_base_transform)

        test_set = HDF5Dataset('test',
                               hdf5_file,
                               transform=acdc_base_transform)
Example #9
    # ...per-batch loop accumulating per-class `intersection` and `union`
    # counts over the evaluation set (beginning of snippet not shown)...

    # calculate IoU over full set
    IoU = calculate_IoU(intersection, union, n_classes=34)
    IoU_dict, IoU_average = calculate_average_IoU(IoU)

    print('IoU per class: ')
    for key, value in IoU_dict.items():
        print(key, ' : ', value)
    print('IoU average for 34 classes: ', IoU_average)
    IoU_19_average = calculate_IoU_train_classes(IoU)
    print('IoU average for 19 classes: ', IoU_19_average)

    return


if __name__ == '__main__':
    '''model'''
    model = UNet(n_classes=34,
                 depth=4,
                 wf=3,
                 batch_norm=True,
                 padding=True,
                 up_mode='upconv')

    model_weights = torch.load('weights/unet-id1.pt')

    evaluate(model_weights=model_weights,
             model=model,
             dataset='val',
             batch_size=4)
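
Note: calculate_IoU is not shown. A minimal, hedged sketch consistent with the call above (per-class intersection/union counts in, per-class IoU out):

import numpy as np

def calculate_IoU(intersection, union, n_classes=34):
    # assumed shapes: length-n_classes count arrays
    intersection = np.asarray(intersection, dtype=np.float64)
    union = np.asarray(union, dtype=np.float64)
    iou = np.full(n_classes, np.nan)
    valid = union > 0  # classes absent from both prediction and GT stay NaN
    iou[valid] = intersection[valid] / union[valid]
    return iou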
Example #10
def train(gpu, args):

    # rank = args.nr * args.num_gpus + gpu

    # dist.init_process_group(backend="nccl", world_size=args.world_size, rank=rank)

    print("Using {} percent of the target data for finetuning".format(
        args.training_frac * 100))

    if args.batch_size == 1 and args.use_bn is True:
        raise ValueError("batch norm requires batch_size > 1")

    torch.autograd.set_detect_anomaly(True)
    torch.manual_seed(args.torch_seed)
    torch.cuda.manual_seed(args.cuda_seed)

    torch.cuda.set_device(gpu)

    DATASET_NAME = args.dataset_name
    DATA_ROOT = args.data_root
    OVERLAYS_ROOT = args.overlays_root

    if args.model_name == 'dss':
        model = DSS_Net(args, n_channels=3, n_classes=1, bilinear=True)
        loss = FocalLoss()
    elif args.model_name == 'unet':
        model = UNet(args)
        loss = nn.BCELoss()
    else:
        raise NotImplementedError

    #model = nn.SyncBatchNorm(model)

    print(f"Using {torch.cuda.device_count()} GPUs...")

    # define dataset
    if DATASET_NAME == 'synthetic':
        assert (args.overlays_root != "")
        train_dataset = SmokeDataset(dataset_limit=args.num_examples,
                                     training=True,
                                     training_frac=args.training_frac)
        # train_dataset = SyntheticSmokeTrain(args, DATA_ROOT, OVERLAYS_ROOT, dataset_limit=args.num_examples)
        # train_sampler = DistributedSampler(train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True)
        train_dataloader = DataLoader(
            train_dataset,
            args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)  #, sampler=train_sampler)
        if args.validate:
            val_dataset = SmokeDataset()
        else:
            val_dataset = None
        val_dataloader = DataLoader(val_dataset,
                                    args.batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers,
                                    pin_memory=True) if val_dataset else None
    else:
        raise NotImplementedError

    # define augmentations
    augmentations = None  #SyntheticAugmentation(args)

    # load the model
    print("Loding model and augmentations and placing on gpu...")

    if args.cuda:
        if augmentations is not None:
            augmentations = augmentations.cuda()

        model = model.cuda(device=gpu)

        # if args.num_gpus > 0 or torch.cuda.device_count() > 0:
        #     model = DistributedDataParallel(model, device_ids=[gpu])

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"The model has {num_params} learnable parameters")

    # load optimizer and lr scheduler
    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     betas=[args.momentum, args.beta],
                     weight_decay=args.weight_decay)

    if args.lr_sched_type == 'plateau':
        print("Using plateau lr schedule")
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=args.lr_gamma,
            verbose=True,
            mode='min',
            patience=10)
    elif args.lr_sched_type == 'step':
        print("Using step lr schedule")
        milestones = [30]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=args.lr_gamma)
    elif args.lr_sched_type == 'none':
        lr_scheduler = None

    # set up logging
    # if not args.no_logging and gpu == 0:
    if not os.path.isdir(args.log_dir):
        os.mkdir(args.log_dir)
    log_dir = os.path.join(args.log_dir, args.exp_dir)
    if not os.path.isdir(log_dir):
        os.mkdir(log_dir)
    if args.exp_name == "":
        exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
    else:
        exp_name = args.exp_name
    log_dir = os.path.join(log_dir, exp_name)
    writer = SummaryWriter(log_dir)

    if args.ckpt != "" and args.use_pretrained:
        state_dict = torch.load(args.ckpt)  #['state_dict']
        model.load_state_dict(copy_state_dict(state_dict))
    elif args.start_epoch > 0:
        load_epoch = args.start_epoch - 1
        ckpt_fp = os.path.join(log_dir, f"{load_epoch}.ckpt")

        print(f"Loading model from {ckpt_fp}...")

        ckpt = torch.load(ckpt_fp)
        assert ckpt['epoch'] == load_epoch, \
            "epoch from checkpoint does not match args.start_epoch"
        model.load_state_dict(ckpt['state_dict'])

    model.train()

    # run training loop
    for epoch in range(args.start_epoch, args.epochs + 1):
        print(f"Training epoch: {epoch}...")
        # train_sampler.set_epoch(epoch)
        freeze_layers = epoch < 10
        train_loss_avg, pred_mask, input_dict = train_one_epoch(
            args, model, loss, train_dataloader, optimizer, augmentations,
            lr_scheduler, freeze_layers)
        if gpu == 0:
            print(f"\t Epoch {epoch} train loss avg:")
            pprint(train_loss_avg)

        if val_dataset is not None:
            print(f"Validation epoch: {epoch}...")
            val_loss_avg = eval(args, model, loss, val_dataloader,
                                augmentations)
            print(f"\t Epoch {epoch} val loss avg: {val_loss_avg}")

        if not args.no_logging and gpu == 0:
            writer.add_scalar('loss/train', train_loss_avg, epoch)
            if epoch % args.log_freq == 0:
                visualize_output(args, input_dict, pred_mask, epoch, writer)

        if args.lr_sched_type == 'plateau':
            lr_scheduler.step(train_loss_avg)
        elif args.lr_sched_type == 'step':
            lr_scheduler.step()

        # save model
        if not args.no_logging:
            if epoch % args.save_freq == 0 or epoch == args.epochs:
                fp = os.path.join(log_dir, f"finetune_{epoch}.ckpt")
                print("saving model to: ", fp)
                torch.save(model.state_dict(), fp)

            writer.flush()

    return
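
Note: copy_state_dict is used here and in the test() function of Example #17. A common, hedged implementation that strips the 'module.' prefix DataParallel/DistributedDataParallel adds to parameter names:

from collections import OrderedDict

def copy_state_dict(state_dict):
    # assumed behavior: drop a leading 'module.' from every key
    new_state = OrderedDict()
    for key, value in state_dict.items():
        name = key[len('module.'):] if key.startswith('module.') else key
        new_state[name] = value
    return new_state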
Example #11
    # from one_hot vector to categorical
    prediction = np.argmax(prediction, axis=0)
    # convert the predicted mask to rgb image
    prediction = convert_mask_to_rgb_image(prediction)
    # remove the batch dim and the channel dim of img
    img = img.squeeze()
    # convert img to a numpy array
    if use_cuda and torch.cuda.is_available():
        img = img.cpu().numpy()
    else:
        img = img.numpy()

    return prediction


model = UNet(num_classes=4, in_channels=3)
model.load_weights('UNet.pt')

if len(sys.argv) > 1:
    use_cuda = (sys.argv[1] == 'True')
else:
    use_cuda = True

print("model loaded!")

device_name = 'cuda:0' if use_cuda else 'cpu'
if use_cuda and not torch.cuda.is_available():
    warnings.warn("CUDA is not available. Suppress this warning by passing "
                  "use_cuda=False.")
    device_name = 'cpu'
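
Note: the snippet stops before the model is moved to the selected device; the natural, hedged continuation:

device = torch.device(device_name)
model = model.to(device)
model.eval()  # the prediction helper above runs inference only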
Example #12
    '''data'''    
    data_generator = load_data('datasets/citys', batch_size=batch_size, shuffle=True)
        
    '''device''' 
    no_cuda = False
    use_cuda = not no_cuda and torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    print('using device:', device)
    
    weights = torch.FloatTensor(list(WEIGHTS.values()))
    weights = weights.to(device)
    
    '''model'''
    model = UNet(n_classes=34,
                 depth=4,
                 wf=3,
                 batch_norm=True,
                 padding=True,
                 up_mode='upconv').to(device)
    
#    from torchsummary import summary
#    summary(model, input_size=(3, 1024, 2048))
    
    '''training'''
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr_init,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=5e-4)

    from torch.optim.lr_scheduler import StepLR
    scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=gamma)
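
Note: the training loop itself is not shown. A minimal, hedged sketch consistent with the objects built above (epochs and the generator's (image, mask) batch format are assumptions):

criterion = torch.nn.CrossEntropyLoss(weight=weights)

for epoch in range(epochs):  # epochs: assumed hyperparameter
    model.train()
    for img, mask in data_generator:  # assumed (image, label-mask) batches
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        out = model(img)  # (N, 34, H, W) logits
        loss = criterion(out, mask)
        loss.backward()
        optimizer.step()
    scheduler.step()  # StepLR decays the lr every lr_step_size epochs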
Example #13
def channel_vis_driver(model_name, checkpoint_path, data_path, dataset, conv_layer, channels, init_img_size, upscale_steps, upscale_factor, lr, update_steps, grid, out_path, verbose):
    if model_name == 'unet':
        if not os.path.exists(checkpoint_path):
            sys.exit('Specified checkpoint cannot be found')

        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        model = UNet(num_classes=len(datasets.Cityscapes.classes), encoder_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
    elif model_name == 'vggmod':
        model = models.vgg11(pretrained=True)
    else:
        sys.exit('No model provided, please specify --unet or --vgg11 to analyze the UNet or VGG11 encoder, respectively')

    # Set model to evaluation mode and fix the parameter values
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    layer = get_conv_layer(model, conv_layer)

    analyzer = LayerActivationAnalysis(model, layer)


    # Save a grid of channel activation visualizations
    if grid:
        if not channels:
            # Get a random sample of 9 activated channels
            channels = analyzer.get_activated_filter_indices()
            np.random.shuffle(channels)
            channels = channels[:9]

        imgs = []
        for i, channel in enumerate(channels):
            if verbose:
                print('Generating image {} of {}...'.format(i+1, len(channels)))

            img = analyzer.get_max_activating_image(channel,
                                                    initial_img_size=init_img_size,
                                                    upscaling_steps=upscale_steps,
                                                    upscaling_factor=upscale_factor,
                                                    lr=lr,
                                                    update_steps=update_steps,
                                                    verbose=verbose)

            imgs.append(img)
        channel_string = '-'.join(str(channel_id) for channel_id in channels)
        output_dest = os.path.join(out_path, '{}_layer{}_channels{}.png'.format(model_name, conv_layer, channel_string))
        save_image_grid(imgs, output_dest)

    # Save a channel activation visualization for each specified channel
    elif channels is not None:
        for channel in channels:

            img = analyzer.get_max_activating_image(channel,
                                                    initial_img_size=init_img_size,
                                                    upscaling_steps=upscale_steps,
                                                    upscaling_factor=upscale_factor,
                                                    lr=lr,
                                                    update_steps=update_steps,
                                                    verbose=verbose)

            output_dest = os.path.join(out_path, '{}_layer{}_channel{}.png'.format(model_name, conv_layer, channel))
            save_image(img, output_dest)

    else:
        # Compute the average number of channels activated in each layer
        if data_path and dataset:
            layers = [get_conv_layer(model, i) for i in [1,2,3,4,5,6,7,8]]
            avg = analyzer.get_avg_activated_channels(layers, data_path, dataset, 100)
            print('Average number of channels activated per convolutional layer: {}'.format(avg))

        # Output the channels activated by a randomly initialized image
        else:
            activated_channels = analyzer.get_activated_filter_indices(initial_img_size=init_img_size)
            print('Output channels in conv layer {} activated by random image input:'.format(conv_layer))
            print(activated_channels)
            print()
            print('(Total of {} activated channels)'.format(len(activated_channels)))
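
Note: an illustrative, hedged invocation (mirroring the view_max_activating mode in Example #17; all values below are placeholders):

channel_vis_driver(model_name='unet',
                   checkpoint_path='checkpoints/unet.pt',  # assumed path
                   data_path=None, dataset=None,
                   conv_layer=4, channels=[3, 17, 42],
                   init_img_size=56, upscale_steps=12, upscale_factor=1.2,
                   lr=0.05, update_steps=20,
                   grid=True, out_path='./', verbose=True)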
Example #14
def main():

    args = get_arguments()

    # configuration
    CONFIG = Dict(yaml.safe_load(open(args.config)))

    # writer
    if CONFIG.writer_flag:
        writer = SummaryWriter(CONFIG.result_path)
    else:
        writer = None


    """ DataLoader """
    labeled_train_data = PartAffordanceDataset(CONFIG.labeled_data,
                                                config=CONFIG,
                                                transform=transforms.Compose([
                                                        CenterCrop(CONFIG),
                                                        ToTensor(),
                                                        Normalize()
                                                ]))

    if CONFIG.train_mode == 'semi':
        unlabeled_train_data = PartAffordanceDatasetWithoutLabel(CONFIG.unlabeled_data,
                                                                config=CONFIG,
                                                                transform=transforms.Compose([
                                                                    CenterCrop(CONFIG),
                                                                    ToTensor(),
                                                                    Normalize()
                                                                ]))
    else:
        unlabeled_train_data = None

    test_data = PartAffordanceDataset(CONFIG.test_data,
                                    config=CONFIG,
                                    transform=transforms.Compose([
                                        CenterCrop(CONFIG),
                                        ToTensor(),
                                        Normalize()
                                    ]))

    train_loader_with_label = DataLoader(labeled_train_data, batch_size=CONFIG.batch_size, shuffle=True, num_workers=CONFIG.num_workers)
    
    if unlabeled_train_data is not None:
        train_loader_without_label = DataLoader(unlabeled_train_data, batch_size=CONFIG.batch_size, shuffle=True, num_workers=CONFIG.num_workers)
    
    test_loader = DataLoader(test_data, batch_size=CONFIG.batch_size, shuffle=False, num_workers=CONFIG.num_workers)


    """ model """
    if CONFIG.model == 'FCN8s':
        model = FCN8s(CONFIG.in_channel, CONFIG.n_classes)
    elif CONFIG.model == 'SegNetBasic':
        model = SegNetBasic(CONFIG.in_channel, CONFIG.n_classes)
    elif CONFIG.model == 'UNet':
        model = UNet(CONFIG.in_channel, CONFIG.n_classes)
    else:
        print('This model doesn\'t exist in the model directory')
        sys.exit(1)


    if CONFIG.train_mode == 'full':
        model.apply(init_weights)
        model.to(args.device)
    elif CONFIG.train_mode == 'semi':
        if CONFIG.pretrain_model is not None:
            model.load_state_dict(
                torch.load(CONFIG.pretrain_model,
                           map_location=lambda storage, loc: storage))
        else:
            model.apply(init_weights)

        model.to(args.device)
        if CONFIG.model_d == 'FCDiscriminator':
            model_d = FCDiscriminator(CONFIG)
        else:
            model_d = Discriminator(CONFIG)

        model_d.apply(init_weights)
        model_d.to(args.device)
    else:
        print('This training mode doesn\'t exist.')
        sys.exit(1)



    """ class weight after center crop. See dataset.py """
    if CONFIG.class_weight_flag:
        class_num = torch.tensor([2078085712, 34078992, 15921090, 12433420, 
                                    38473752, 6773528, 9273826, 20102080])

        total = class_num.sum().item()

        frequency = class_num.float() / total
        median = torch.median(frequency)

        class_weight = median / frequency
        class_weight = class_weight.to(args.device)
    else:
        class_weight = None



    """ supplementary constant for discriminator """
    if CONFIG.noisy_label_flag:
        if CONFIG.one_label_smooth:
            real = torch.full((CONFIG.batch_size, CONFIG.height, CONFIG.width), CONFIG.real_label).to(args.device)
            fake = torch.zeros(CONFIG.batch_size, CONFIG.height, CONFIG.width).to(args.device)
        else:
            real = torch.full((CONFIG.batch_size, CONFIG.height, CONFIG.width), CONFIG.real_label).to(args.device)
            fake = torch.full((CONFIG.batch_size, CONFIG.height, CONFIG.width), CONFIG.fake_label).to(args.device)        
    else:
        real = torch.ones(CONFIG.batch_size, CONFIG.height, CONFIG.width).to(args.device)
        fake = torch.zeros(CONFIG.batch_size, CONFIG.height, CONFIG.width).to(args.device)


    """ optimizer, criterion """
    optimizer = optim.Adam(model.parameters(), lr=CONFIG.learning_rate)
    criterion = nn.CrossEntropyLoss(weight=class_weight, ignore_index=255)

    if CONFIG.train_mode == 'semi':
        criterion_bce = nn.BCELoss()    # discriminator includes sigmoid layer
        optimizer_d = optim.Adam(model_d.parameters(), lr=CONFIG.learning_rate)

    losses_full = []
    losses_semi = []
    losses_d = []
    val_iou = []
    mean_iou = []
    mean_iou_without_bg = []
    best_mean_iou = 0.0

    for epoch in tqdm.tqdm(range(CONFIG.max_epoch)):

        if CONFIG.poly_lr_decay:
            poly_lr_scheduler(optimizer, CONFIG.learning_rate, 
                epoch, max_iter=CONFIG.max_epoch, power=CONFIG.poly_power)
            
            if CONFIG.train_mode == 'semi':
                poly_lr_scheduler(optimizer_d, CONFIG.learning_rate_d, 
                    epoch, max_iter=CONFIG.max_epoch, power=CONFIG.poly_power)
        
        epoch_loss_full = 0.0
        epoch_loss_d = 0.0
        epoch_loss_semi = 0.0

        # only supervised learning
        if CONFIG.train_mode == 'full':    

            for i, sample in enumerate(train_loader_with_label):

                loss_full = full_train(model, sample, criterion, optimizer, CONFIG, args.device)
                
                epoch_loss_full += loss_full

            losses_full.append(epoch_loss_full / (i + 1))
            losses_d.append(0.0)
            losses_semi.append(0.0)

        
        # semi-supervised learning
        elif CONFIG.train_mode == 'semi':
            
            # first, adversarial learning
            if epoch < CONFIG.adv_epoch:
                
                for i, sample in enumerate(train_loader_with_label):
                
                    loss_full, loss_d = adv_train(
                                            model, model_d, sample, criterion, criterion_bce,
                                            optimizer, optimizer_d, real, fake, CONFIG, args.device)
                
                    epoch_loss_full += loss_full
                    epoch_loss_d += loss_d
                    
                losses_full.append(epoch_loss_full / (i + 1))   # mean loss over all batches
                losses_d.append(epoch_loss_d / (i + 1))
                losses_semi.append(0.0)
                    
            # semi-supervised learning
            else:
                cnt_full = 0
                cnt_semi = 0
                
                for (sample1, sample2) in zip_longest(train_loader_with_label, train_loader_without_label):
                    
                    if sample1 is not None:
                        loss_full, loss_d = adv_train(
                                                model, model_d, sample1, criterion, criterion_bce,
                                                optimizer, optimizer_d, real, fake, CONFIG, args.device)
                        
                        epoch_loss_full += loss_full
                        epoch_loss_d += loss_d
                        cnt_full += 1

                    if sample2 is not None:
                        loss_semi = semi_train(
                                                model, model_d, sample2, criterion, criterion_bce,
                                                optimizer, optimizer_d, real, fake, CONFIG, args.device)
                        epoch_loss_semi += loss_semi
                        cnt_semi += 1

                losses_full.append(epoch_loss_full / cnt_full)   # mean loss over all samples
                losses_d.append(epoch_loss_d / cnt_full)
                losses_semi.append(epoch_loss_semi / cnt_semi)


        else:
            print('This train mode can\'t be used. Choose full or semi')
            sys.exit(1)


        # validation
        val_iou.append(eval_model(model, test_loader, CONFIG, args.device))
        mean_iou.append(val_iou[-1].mean().item())
        mean_iou_without_bg.append(val_iou[-1][1:].mean().item())

        if best_mean_iou < mean_iou[-1]:
            best_mean_iou = mean_iou[-1]
            torch.save(model.state_dict(), CONFIG.result_path + '/best_mean_iou_model.prm')
            if CONFIG.train_mode == 'semi':
                torch.save(model_d.state_dict(), CONFIG.result_path + '/best_mean_iou_model_d.prm')

        if epoch % 50 == 0 and epoch != 0:
            torch.save(model.state_dict(), CONFIG.result_path + '/epoch_{}_model.prm'.format(epoch))
            if CONFIG.train_mode == 'semi':
                torch.save(model_d.state_dict(), CONFIG.result_path + '/epoch_{}_model_d.prm'.format(epoch))

        if writer is not None:
            writer.add_scalar("loss_full", losses_full[-1], epoch)
            writer.add_scalar("loss_d", losses_d[-1], epoch)
            writer.add_scalar("loss_semi", losses_semi[-1], epoch)
            writer.add_scalar("mean_iou", mean_iou[-1], epoch)
            writer.add_scalar("mean_iou_without_background", mean_iou_without_bg[-1], epoch)
            writer.add_scalars("class_IoU", {'iou of class 0': val_iou[-1][0],
                                            'iou of class 1': val_iou[-1][1],
                                            'iou of class 2': val_iou[-1][2],
                                            'iou of class 3': val_iou[-1][3],
                                            'iou of class 4': val_iou[-1][4],
                                            'iou of class 5': val_iou[-1][5],
                                            'iou of class 6': val_iou[-1][6],
                                            'iou of class 7': val_iou[-1][7]}, epoch)

        print('epoch: {}\tloss_full: {:.5f}\tloss_d: {:.5f}\tloss_semi: {:.5f}\tmean IOU: {:.3f}\tmean IOU w/o bg: {:.3f}'
            .format(epoch, losses_full[-1], losses_d[-1], losses_semi[-1], mean_iou[-1], mean_iou_without_bg[-1]))


    torch.save(model.state_dict(), CONFIG.result_path + '/final_model.prm')
    if CONFIG.train_mode == 'semi':
        torch.save(model_d.state_dict(), CONFIG.result_path + '/final_model_d.prm')
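
Note: poly_lr_scheduler is not defined in the snippet. A common, hedged implementation of polynomial decay matching the call signature used above:

def poly_lr_scheduler(optimizer, init_lr, epoch, max_iter=100, power=0.9):
    # assumed behavior: lr = init_lr * (1 - epoch / max_iter) ** power
    lr = init_lr * (1 - epoch / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr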
Example #15
    def __init__(self,
                 in_channels,
                 num_classes,
                 num_blocks,
                 conv_type='double',
                 residual=False,
                 depth=4,
                 activation='relu',
                 dilation=1,
                 upsample_type='upsample',
                 **kwargs):
        r"""
        
        Parameters:
        -----------
        
        in_channels: int
            The number of channels in the input image.
            
        num_classes: int
            The number of channels in the output mask. Each channel corresponds to one of the classes and contains
            a mask of probabilities for image pixels to belong to this class.
        
        num_blocks: int
            The number of UNet blocks in the model. Must be greater than 1.
            
        conv_type: 'single', 'double' or 'triple' (default 'double')
            Defines the number of convolutions and activations in the model's blocks. If it is 'single', there
            is one convolutional layer with kernel_size=3, padding=1, dilation=1, followed by activation. If
            it is 'double' or 'triple', it is complemented once or twice by a convolutional layer with
            kernel_size=3 and the chosen dilation with corresponding padding, followed by activation.
        
        residual: bool (default False)
            Defines if the model's convolutional blocks have residual connections.
        
        depth: int (default 4)
            Defines the depth of the encoding-decoding part in UNet blocks. Must be greater than 2.
        
        activation: 'relu', 'prelu' or 'leaky_relu' (default 'relu')
            Defines the type of the activation function in the model's convolutional blocks.
        
        dilation: int (default 1) or list
            The dilation for the model's blocks convolutional layers.
        
        upsample_type: 'upsample' or 'convtranspose'
            Defines the type of upsampling in the UNet blocks.
        
        channels_sequence: list
            The list of out_channels for the decoding part of the UNet blocks. Its length must match the depth.
            Example: for depth=4, it can be [64, 128, 256, 512].
            If it is not set, it is derived automatically, as described in the original U-Net paper.
            
        Applying:
        ---------
        
        >>> model = StackUNet(3, 1, 3, activation='leaky_relu', depth=3, channels_sequence=[32, 64, 64], dilation=2)
        >>> input = torch.rand(1, 3, 256, 256)
        >>> output = model(input)

        To inspect the model's details, use torchsummary:
        
        >>> from torchsummary import summary
        >>> model = StackUNet(3, 1, 3)
        >>> summary(model, input_size=(3, 256, 256))
        """
        super().__init__()

        # Check if all model parameters are set correctly.

        if num_blocks < 2:
            raise ValueError(
                "The number of blocks is expected to be greater than 1.")

        if conv_type not in ['single', 'double', 'triple']:
            raise ValueError(
                "The type of convolution blocks is expected to be 'single', 'double' or 'triple'."
            )
        if conv_type == 'single' and residual:
            raise NotImplementedError(
                "Residual connections are not supported for 'single' convolution blocks."
            )

        if depth < 3:
            raise ValueError(
                "The depth of the encoding and decoding part of the model is expected to be greater than 2."
            )

        if activation not in ['relu', 'prelu', 'leaky_relu']:
            raise ValueError(
                "The activation for convolution blocks is expected to be 'relu', 'prelu' or 'leaky_relu'."
            )
        if isinstance(dilation, int):
            if dilation not in [1, 2, 3]:
                raise ValueError(
                    "The dilation for convolution blocks is expected to be 1, 2 or 3."
                )
        if upsample_type not in ['upsample', 'convtranspose']:
            raise ValueError(
                "The upsample type is expected to be 'upsample' or 'convtranspose'."
            )

        if 'channels_sequence' in kwargs.keys():
            channels_sequence = kwargs['channels_sequence']
            if len(channels_sequence) != depth:
                raise ValueError(
                    "The length of channels_sequence must match the depth of the decoding part of the model."
                )
            for val in channels_sequence:
                if not isinstance(val, int) or val < 1:
                    raise ValueError(
                        "The amount of channels must be a positive integer.")

            for i in range(1, depth):
                if channels_sequence[i] < channels_sequence[i - 1]:
                    raise ValueError(
                        "The amount of channels is expected to be non-decreasing.")

        # Define the number of out_channels in convolutional blocks in encoding part of the model.

        else:
            channels_sequence = [32]
            for i in range(depth - 1):
                if i < 1:
                    channels_sequence.append(channels_sequence[-1] * 2)
                else:
                    channels_sequence.append(channels_sequence[-1])

        # Layers initialization

        self.num_blocks = num_blocks
        out_channels = channels_sequence[0]

        self.UNet_block = UNet(in_channels,
                               out_channels,
                               conv_type=conv_type,
                               residual=residual,
                               depth=depth,
                               activation=activation,
                               dilation=dilation,
                               is_block=True,
                               upsample_type=upsample_type,
                               channels_sequence=channels_sequence)

        self.middle_conv = nn.Conv2d(out_channels + in_channels,
                                     in_channels,
                                     kernel_size=3,
                                     padding=1)
        self.last_conv = nn.Conv2d(out_channels,
                                   num_classes,
                                   kernel_size=3,
                                   padding=1)
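
Note: the forward pass is not shown. A hedged reconstruction implied by the layer shapes above (middle_conv maps each intermediate output back to in_channels, so the single shared UNet_block can be reapplied; whether the blocks share weights is an assumption):

    def forward(self, x):
        inp = x
        for _ in range(self.num_blocks - 1):
            out = self.UNet_block(inp)
            inp = self.middle_conv(torch.cat([out, inp], dim=1))
        out = self.UNet_block(inp)
        return self.last_conv(out)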
Example #16
        model = ResNet(num_classes=10)
    elif args.model == 'IFT725Net':
        model = IFT725Net(num_classes=10)
    elif args.model == 'IFT725UNet':
        model = IFT725UNet(num_classes=4)
        args.dataset = 'acdc'

        train_set = HDF5Dataset('train',
                                hdf5_file,
                                transform=acdc_base_transform)
        test_set = HDF5Dataset('test',
                               hdf5_file,
                               transform=acdc_base_transform)
    elif args.model == 'UNet':
        if args.dataset == 'projetSession':
            model = UNet(num_classes=2, in_channels=3)
            train_set = ProjetHiver_DataSet('train',
                                            ProjetSession_file,
                                            transform=highway_transform)
            test_set = ProjetHiver_DataSet('test',
                                           ProjetSession_file,
                                           transform=highway_transform)
        elif args.dataset == 'highway':
            model = UNet(num_classes=4, in_channels=3)
            train_set = Highway_DataSet('train',
                                        highway_file,
                                        transform=highway_transform)
            test_set = Highway_DataSet('test',
                                       highway_file,
                                       transform=highway_transform)
        else:
Example #17
def main():
    args = get_cli_arguments()

    if (args.checkpoint is not None) and (not os.path.exists(args.checkpoint)):
        sys.exit('Specified checkpoint cannot be found')

    if args.mode == 'train':
        if args.dataset != 'cityscapes':
            sys.exit("Model can only be trained on cityscapes dataset")

        dataset = load_data(args.path, args.dataset, resize=not args.no_resize)

        if args.subset:
            sampler = torch.utils.data.SubsetRandomSampler(np.arange(50))
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=args.batch_size, sampler=sampler)
        else:
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=args.batch_size, shuffle=True)

        model = UNet(num_classes=len(datasets.Cityscapes.classes),
                     pretrained=args.pretrained)

        if not args.pretrained:
            set_parameter_required_grad(model, True)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
        if (args.savedir is not None) and (not os.path.exists(args.savedir)):
            os.makedirs(args.savedir)
        train(model,
              dataloader,
              criterion,
              optimizer,
              num_epochs=args.epochs,
              checkpoint_path=args.checkpoint,
              save_path=args.savedir)
        return

    if args.mode == 'test':
        dataset = load_data(args.path,
                            args.dataset,
                            resize=not args.no_resize,
                            split='val')

        if args.subset:
            sampler = torch.utils.data.SubsetRandomSampler(np.arange(50))
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=args.batch_size, sampler=sampler)
        else:
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=args.batch_size, shuffle=True)

        model = UNet(num_classes=len(datasets.Cityscapes.classes),
                     pretrained=args.pretrained)
        validate(model, dataloader, args.checkpoint)

    if args.mode == 'activations':
        if args.model is None:
            sys.exit("Must specify model to use with --model argument")
        dataset = load_data(args.path, args.dataset, resize=not args.no_resize)
        if args.subset:
            sampler = torch.utils.data.SubsetRandomSampler(np.arange(50))
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=args.batch_size, sampler=sampler)
        else:
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=args.batch_size, shuffle=True)

        if args.model == 'unet':
            model = UNet(num_classes=len(datasets.Cityscapes.classes))
            if args.checkpoint:
                checkpoint = torch.load(
                    args.checkpoint, map_location=lambda storage, loc: storage)
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                print(
                    "NOTE: Getting activations for an untrained network. Specify a trained model with the "
                    "--checkpoint argument.")
        elif args.model == 'vggmod':
            model = VGGmod()
        else:
            # any other value: retrieve activations for the UNet first, then
            # fall through to VGGmod so both sets are generated
            model = UNet(num_classes=len(datasets.Cityscapes.classes))
            if args.checkpoint:
                checkpoint = torch.load(args.checkpoint)
                model.load_state_dict(checkpoint['model_state_dict'])
            set_parameter_required_grad(model, True)
            retrieve_activations(model, dataloader, args.dataset)
            model = VGGmod()

        set_parameter_required_grad(model, True)

        retrieve_activations(model, dataloader, args.dataset)

    if args.mode == 'view_activations':
        file_1 = os.path.join(args.path, 'VGGmod_activations')
        file_2 = os.path.join(args.path, 'UNet_activations_matched')
        if not os.path.exists(file_1) or not os.path.exists(file_2):
            exit(
                "Could not load activations from " + args.path +
                ". If you have not generated activations for both UNet "
                "and VGG11, run instead with the \"--mode activations\" parameter."
            )

        activs1, activs2 = load_activations(file_1, file_2)
        visualize_batch(activs1,
                        activs2,
                        batch_num=args.batch_num,
                        start_layer=args.start_layer,
                        stop_layer=args.stop_layer)

    if args.mode == 'compare_activations':
        file_1 = os.path.join(args.path, 'VGGmod_activations')
        file_2 = os.path.join(args.path, 'UNet_activations')
        match_channels(file_1, file_2, args.type)

    if args.mode == 'view_max_activating':
        channel_vis_driver(args.model, args.checkpoint, args.path,
                           args.dataset, args.conv_layer, args.channels,
                           args.img_size, args.upscale_steps,
                           args.upscale_factor, args.learning_rate,
                           args.opt_steps, args.grid, args.path, args.verbose)
Example #18
def test(gpu, args):
    print("Starting...")
    print("Using {} percent of data for testing".format(100 -
                                                        args.training_frac *
                                                        100))
    torch.autograd.set_detect_anomaly(True)
    #torch.manual_seed(args.torch_seed)
    #torch.cuda.manual_seed(args.cuda_seed)
    torch.cuda.set_device(gpu)

    DATA_ROOT = args.data_root
    NUM_IMAGES = args.num_test_images
    CHKPNT_PTH = args.checkpoint_path

    if args.model_name == 'dss':
        model = DSS_Net(args, n_channels=3, n_classes=1, bilinear=True)
        loss = FocalLoss()
    elif args.model_name == 'unet':
        model = UNet(args)
        loss = nn.BCELoss()
    else:
        raise NotImplementedError

    state_dict = torch.load(CHKPNT_PTH)
    new_state_dict = copy_state_dict(state_dict)  #OrderedDict()
    # for k, v in state_dict.items():
    #     name = k[7:]
    #     names = name.strip().split('.')
    #     if names[1] == 'inc':
    #         names[1] = 'conv1'
    #     name = '.'.join(names)
    #     # print(names)
    #     new_state_dict[name] = v

    # print("Expected values:", model.state_dict().keys())
    model.load_state_dict(new_state_dict)
    model.cuda(gpu)
    model.eval()

    if args.test_loader == 'annotated':
        dataset = SmokeDataset(dataset_limit=NUM_IMAGES,
                               training_frac=args.training_frac)
        # dataset = SyntheticSmokeTrain(args,dataset_limit=50)
    else:
        dataset = SimpleSmokeVal(args=args,
                                 data_root=DATA_ROOT,
                                 dataset_limit=NUM_IMAGES)
    dataloader = DataLoader(dataset,
                            1,
                            shuffle=True,
                            num_workers=4,
                            pin_memory=True)  #, sampler=train_sampler)

    # if not args.no_logging and gpu == 1:
    if not os.path.isdir(args.log_dir):
        os.mkdir(args.log_dir)
    log_dir = os.path.join(args.log_dir, args.exp_dir)
    if not os.path.isdir(log_dir):
        os.mkdir(log_dir)
    if args.exp_name == "":
        exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
    else:
        exp_name = args.exp_name
        log_dir = os.path.join(log_dir, exp_name)
    writer = SummaryWriter(log_dir)

    iou_sum = 0
    iou_count = 0
    iou_ = 0
    for idx, data in enumerate(dataloader):
        if args.test_loader == 'annotated':
            out_img, iou_ = val_step_with_loss(data, model)
            iou_sum += iou_
            iou_count += 1
            writer.add_images('true_mask', data['target_mask'] > 0, idx)
        else:
            out_img = val_step(data, model)
        writer.add_images('input_img', data['input_img'], idx)
        writer.add_images('pred_mask', out_img, idx)
        writer.add_scalar('accuracy/test', iou_, idx)
        writer.flush()
        # print("Step: {}/{}: IOU: {}".format(idx,len(dataloader), iou_))
        if idx > len(dataloader):
            break
    if iou_count > 0:
        iou = iou_sum / iou_count
        writer.add_scalar('mean_accuracy/test', iou)
        print("Mean IOU: ", iou)
    print("Done")