예제 #1
0
def make_model():
    """Build, compile and return the two-input model.

    Five sub-networks are instantiated (``encoder``, two ``hourglass``
    modules, ``decoder`` and ``net2``) and then wired together as if each
    were a single layer.  The resulting graph has two inputs of shape
    ``(336, 336, 3)``:

    * the first input runs through encoder -> hourglass -> hourglass ->
      decoder;
    * the second input runs through ``net2``;
    * the outputs of the two branches are combined with ``Add``.

    The model is compiled with the ``root_mean_sq_GxGy`` loss and an
    ``RMSprop`` optimizer created with its default arguments.

    Side effect: the model summary is written to
    ``hourglass_sr_t1_t2.txt`` in the current working directory,
    replacing any existing file of that name.

    Returns:
        The compiled ``Model``.
    """
    side = 336
    quarter = side // 4
    image_shape = (side, side, 3)

    # Instantiate the building blocks.  The hourglass/decoder modules work
    # on feature maps at a quarter of the input resolution.
    enc = encoder(image_shape)
    hg1 = hourglass((quarter, quarter, 512))
    hg2 = hourglass((quarter, quarter, 256))
    dec = decoder((quarter, quarter, 256))
    proSR = net2(image_shape)

    # Connect the blocks into one graph with two image inputs.
    input_tensor_1 = Input(image_shape)
    input_tensor_2 = Input(image_shape)
    hourglass_branch = dec(hg2(hg1(enc(input_tensor_1))))
    direct_branch = proSR(input_tensor_2)
    output = Add()([hourglass_branch, direct_branch])

    model = Model([input_tensor_1, input_tensor_2], output)
    model.compile(loss=root_mean_sq_GxGy, optimizer=RMSprop())

    # model.summary() prints to stdout, so stdout is pointed at the file
    # for the duration of the call.
    with open('hourglass_sr_t1_t2.txt', 'w') as f, redirect_stdout(f):
        model.summary()

    return model
예제 #2
0
    def _build_network(self,
                       inputs,
                       datas=None,
                       n_stacks=1,
                       n_channels=FLAGS.n_landmarks,
                       is_training=True):
        """Run ``inputs`` through ``n_stacks`` chained hourglass modules.

        The first stack receives ``inputs`` unchanged.  Every later stack
        receives ``inputs`` concatenated with the previous stack's output
        along axis 3.  Each stack lives in its own variable scope named
        ``stack_00``, ``stack_01``, ...

        Args:
            inputs: Input tensor.  It is concatenated along axis 3, so it
                must have at least four dimensions (presumably NHWC --
                confirm against the caller).
            datas: Not used by this method; kept so existing callers that
                pass it keep working.
            n_stacks: Number of hourglass modules to chain.
            n_channels: Forwarded to ``models.hourglass`` as
                ``regression_channels``.  The default is read from
                ``FLAGS.n_landmarks`` once, when the method is defined.
            is_training: Forwarded to ``slim.batch_norm`` and
                ``slim.layers.dropout`` through an arg scope.

        Returns:
            A ``(prediction, states)`` tuple.  ``states`` lists the output
            of every stack in order and ``prediction`` is the last of
            them.  With ``n_stacks == 0`` the result is ``(None, [])``.
        """
        states = []
        net = None

        with slim.arg_scope([slim.batch_norm, slim.layers.dropout],
                            is_training=is_training):
            with slim.arg_scope(models.hourglass_arg_scope_tf()):
                # stacked hourglass
                for i in range(n_stacks):
                    with tf.variable_scope('stack_%02d' % i):
                        # Later stacks see the raw input next to the
                        # previous stack's prediction.
                        if net is None:
                            net = inputs
                        else:
                            net = tf.concat((inputs, net), 3)

                        # The second return value of models.hourglass is
                        # not needed here.
                        net, _ = models.hourglass(
                            net,
                            regression_channels=n_channels,
                            classification_channels=0,
                            deconv='transpose',
                            bottleneck='bottleneck_inception')

                        states.append(net)

        return net, states
def get_model(model_path, model_type):
    """Build the requested network, load its weights and return it in eval mode.

    :param model_path: path of a checkpoint readable by ``torch.load``.  The
        loaded object must have a ``'model'`` entry holding the state dict.
        Every ``'module.'`` substring is removed from the parameter names
        (``nn.DataParallel`` adds that prefix when a wrapped model is saved).
    :param model_type: 'UNet', 'UNet11', 'UNet16', 'AlbuNet34', 'MDeNet',
        'EncDec', 'hourglass' or 'MDeNetplus'.  Any other value falls back
        to 'UNet'.
    :return: the model with the checkpoint weights loaded, switched to
        evaluation mode, and moved to the GPU when CUDA is available.
    """

    num_classes = 1

    if model_type == 'UNet11':
        model = UNet11(num_classes=num_classes)
    elif model_type == 'UNet16':
        model = UNet16(num_classes=num_classes)
    elif model_type == 'AlbuNet34':
        model = AlbuNet34(num_classes=num_classes)
    elif model_type == 'MDeNet':
        print('Mine MDeNet..................')
        model = MDeNet(num_classes=num_classes)
    elif model_type == 'EncDec':
        print('Mine EncDec..................')
        model = EncDec(num_classes=num_classes)
    elif model_type == 'hourglass':
        model = hourglass(num_classes=num_classes)
    elif model_type == 'MDeNetplus':
        print('load MDeNetplus..................')
        model = MDeNetplus(num_classes=num_classes)
    elif model_type == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        # Unknown model type: fall back to the plain UNet.
        print('I am here')
        model = UNet(num_classes=num_classes)

    # Deserialize onto the CPU so that a checkpoint written on a GPU machine
    # can also be read on a CPU-only host; load_state_dict copies the values
    # into the model's own parameters afterwards.
    checkpoint = torch.load(str(model_path),
                            map_location=lambda storage, loc: storage)
    state = {
        key.replace('module.', ''): value
        for key, value in checkpoint['model'].items()
    }
    model.load_state_dict(state)

    # Switch to evaluation mode BEFORE the device-dependent return.  The
    # previous code returned model.cuda() first, so on GPU hosts the model
    # was handed back still in training mode.
    model.eval()

    if torch.cuda.is_available():
        return model.cuda()

    return model
예제 #4
0
def main():
    """Command-line entry point: build a model and train it on one fold.

    Steps, in order:

    1. Parse the command-line options and create the ``--root`` directory.
    2. Instantiate the network selected by ``--model``.
    3. When CUDA is available, wrap the model in ``nn.DataParallel`` and
       move it to the GPU.
    4. Build the augmented training ``DataLoader`` for ``--fold``.
    5. Dump the parsed options to ``<root>/params.json``.
    6. Hand everything to ``utils.train``.

    Notes on current behaviour (kept as-is):

    * ``'LinkNet34'`` is an accepted ``--model`` value but has no dedicated
      constructor, so it uses the fallback ``UNet(..., input_channels=3)``.
    * ``'GAN'`` has a constructor entry but is not among the accepted
      ``--model`` choices, so it cannot be selected from the command line.
    * ``--device-ids`` is parsed into integers but the list is not passed
      to ``nn.DataParallel``.
    * The validation file names are only counted, not loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--jaccard-weight', default=0.3, type=float)
    parser.add_argument('--device-ids', type=str, default='0',
                        help='For example 0,1 to run on two GPUs')
    parser.add_argument('--fold', type=int, help='fold', default=0)
    parser.add_argument('--root', default='runs/debug', help='checkpoint root')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--limit', type=int, default=10000,
                        help='number of images in epoch')
    parser.add_argument('--n-epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--workers', type=int, default=12)
    parser.add_argument(
        '--model',
        type=str,
        default='UNet',
        choices=[
            'UNet', 'UNet11', 'LinkNet34', 'UNet16', 'AlbuNet34',
            'MDeNet', 'EncDec', 'hourglass', 'MDeNetplus',
        ],
    )
    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    num_classes = 1

    # Constructors are wrapped in lambdas so that only the selected class
    # name is looked up and only the selected network is built.
    builders = {
        'UNet': lambda: UNet(num_classes=num_classes),
        'UNet11': lambda: UNet11(num_classes=num_classes, pretrained=True),
        'UNet16': lambda: UNet16(num_classes=num_classes, pretrained=True),
        'MDeNet': lambda: MDeNet(num_classes=num_classes, pretrained=True),
        'MDeNetplus': lambda: MDeNetplus(num_classes=num_classes, pretrained=True),
        'EncDec': lambda: EncDec(num_classes=num_classes, pretrained=True),
        'GAN': lambda: GAN(num_classes=num_classes, pretrained=True),
        'AlbuNet34': lambda: AlbuNet34(num_classes=num_classes, pretrained=False),
        'hourglass': lambda: hourglass(num_classes=num_classes, pretrained=True),
    }
    # Some model choices announce themselves before being constructed.
    banners = {
        'MDeNet': 'Mine MDeNet..................',
        'MDeNetplus': 'load MDeNetplus..................',
        'EncDec': 'Mine EncDec..................',
    }

    if args.model in banners:
        print(banners[args.model])
    build = builders.get(args.model)
    if build is None:
        # Fallback for any choice without a dedicated entry.
        model = UNet(num_classes=num_classes, input_channels=3)
    else:
        model = build()

    if torch.cuda.is_available():
        # The ids are parsed (a non-integer entry raises ValueError here)
        # but DataParallel is created with its default device selection.
        if args.device_ids:
            device_ids = [int(token) for token in args.device_ids.split(',')]
        else:
            device_ids = None
        model = nn.DataParallel(model).cuda()   #  nn.DataParallel(model, device_ids=device_ids).cuda()

    cudnn.benchmark = True

    train_file_names, val_file_names = get_split(args.fold)
    print('num train = {}, num_val = {}'.format(len(train_file_names), len(val_file_names)))

    # Augmentations are applied in the listed order; the ImageOnly wrappers
    # presumably restrict a transform to the image and leave the mask
    # untouched -- confirm against their definitions.
    augmentations = [
        CropCVC612(),
        img_resize(512),
        HorizontalFlip(),
        VerticalFlip(),
        Rotate(),
        Rescale(),
        Zoomin(),
        ImageOnly(RandomHueSaturationValue()),
        ImageOnly(Normalize()),
    ]
    train_transform = DualCompose(augmentations)

    train_dataset = Polyp(train_file_names, transform=train_transform, limit=args.limit)
    train_loader = DataLoader(
        dataset=train_dataset,
        shuffle=True,
        num_workers=args.workers,
        batch_size=args.batch_size,
        pin_memory=torch.cuda.is_available(),
    )

    # Record the exact options of this run next to its checkpoints.
    params_text = json.dumps(vars(args), indent=True, sort_keys=True)
    root.joinpath('params.json').write_text(params_text)

    utils.train(
        args=args,
        model=model,
        train_loader=train_loader,
        fold=args.fold,
    )
예제 #5
0
    'stages': 3,
    'dilation': [1],
    'pooling': [
        1,
        1,
        1,
    ],
    'downsample': 8,
    'argmax': 4,
    'pretrained': False
}

# Build the CPM student network from the config dict defined above.
# NOTE(review): 68 is presumably the number of output landmarks -- confirm
# against the cpm_vgg16 signature.
student_cpm = cpm_vgg16(student_cpm_config, 68)
# Debug aid (disabled): dump the CPM architecture.
#print('CPM:\n{:}'.format(student_cpm))
# get_model_infos returns a pair; only the first element is used here.
# NOTE(review): (1, 3, 64, 64) is presumably an NCHW input shape -- confirm.
FLOPs, _ = get_model_infos(student_cpm, (1, 3, 64, 64))
# The parameter count is divided by 1e6 before printing.
print('CPM-Parameters : {:} MB, FLOP : {:} MB.'.format(
    count_network_param(student_cpm) / 1e6, FLOPs))

# Configuration for the hourglass student network.
student_hg_config = {
    'nStack': 4,
    'nModules': 2,
    'nFeats': 256,
    'downsample': 4
}

# Same construction and report as for the CPM student, using the same
# second argument and input shape.  FLOPs and _ are rebound here.
student_hg = hourglass(student_hg_config, 68)
FLOPs, _ = get_model_infos(student_hg, (1, 3, 64, 64))
# Debug aid (disabled).  NOTE(review): this still references student_cpm,
# not student_hg -- looks like a copy-paste leftover.
#print('CPM:\n{:}'.format(student_cpm))
print('HG--Parameters : {:} MB, FLOP : {:} MB.'.format(
    count_network_param(student_hg) / 1e6, FLOPs))