os.makedirs(save_path)
        
        thresd = cfg.TEST.thresd
        model = FusionNet(input_nc=cfg.TRAIN.input_nc,output_nc=cfg.TRAIN.output_nc,ngf=cfg.TRAIN.ngf)
        ckpt = 'model.ckpt'
        ckpt_path = os.path.join(model_path, ckpt)
        checkpoint = torch.load(ckpt_path)

        new_state_dict = OrderedDict()
        state_dict = checkpoint['model_weights']
        for k, v in state_dict.items():
            name = k[7:] # remove module.
            # name = k
            new_state_dict[name] = v

        model.load_state_dict(new_state_dict)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

        PAD = cfg.TRAIN.pad
        img_list = os.listdir(base_path)
        for f_img in img_list:
            print('Inference: ' + f_img, end=' ')
            # raw = np.asarray(Image.open(os.path.join(base_path, f_img)).convert('L'))
            raw = np.asarray(Image.open(os.path.join(base_path, f_img)))
            raw = raw.transpose(2, 0, 1)
            if cfg.TRAIN.track == 'complex':
                if raw.shape[0] == 9959 or raw.shape[0] == 9958:
                    raw_ = np.zeros((10240,10240), dtype=np.uint8)
                    raw_[141:141+9959, 141:141+9958] = raw
                    raw = raw_
Esempio n. 2
0
            model = nn.DataParallel(model)
        else:
            raise AttributeError(
                'Batch size (%d) cannot be equally divided by GPU number (%d)'
                % (cfg.TRAIN.batch_size, cuda_count))
    else:
        print('a single GPU ... ', end='', flush=True)

    ckpt = 'model.ckpt'
    ckpt_path = os.path.join(model_path, ckpt)
    if os.path.isfile(ckpt_path):
        checkpoint = torch.load(ckpt_path)
        iters = checkpoint['current_iter']
        avg_f1_tmp = checkpoint['valid_result']
        if cuda_count > 1:
            model.load_state_dict(checkpoint['model_weights'])
        else:
            new_state_dict = OrderedDict()
            state_dict = checkpoint['model_weights']
            for k, v in state_dict.items():
                name = k[7:]  # remove module.
                new_state_dict[name] = v
            model.load_state_dict(new_state_dict)
    else:
        raise AttributeError('No checkpoint found at %s' % model_path)

    #### save model
    ckpt2 = 'model_simple.ckpt'
    ckpt_path2 = os.path.join(model_path, ckpt2)
    states = {
        'current_iter': iters,