Exemplo n.º 1
0
    def reshape_decode(self, data, shape):
        """Turn one stored sample into an image tensor.

        The storage layout is selected by instance flags:

        * ``self.float_data``: ``data`` is reshaped to ``shape`` and converted
          from CHW to HWC with ``digits.chw_to_hwc``. It is returned without
          any dtype conversion.
        * ``self.data_encoded``: ``data`` is decoded as PNG (with
          ``self.image_dtype``) or JPEG according to ``self.data_mime``. Any
          other mime type is logged and the process exits with status -1.
        * otherwise (unencoded): for the ``'lmdb'`` backend the bytes are first
          decoded with ``tf.decode_raw`` as ``self.image_dtype``. The values are
          then reshaped according to ``self.unencoded_data_format`` (``'chw'``
          is converted to HWC). For 3-channel data stored as ``'bgr'`` they are
          reordered to RGB.

        In the two non-``float_data`` cases the result is cast with
        ``tf.to_float``. Values are NOT rescaled to [0, 1].

        Args:
            data: the serialized sample tensor.
            shape: target shape. In the unencoded ``'chw'`` case only its
                first three entries are used as the reshape dimensions.

        Returns:
            The decoded image tensor.
        """
        # TODO(tzaman): the float_data path is LMDB specific - make it generic.
        if self.float_data:
            data = tf.reshape(data, shape)
            return digits.chw_to_hwc(data)

        # Images may arrive in any of several encodings, see
        # https://github.com/tensorflow/tensorflow/issues/4009
        if self.data_encoded:
            # Distinguish between mime types.
            if self.data_mime == 'image/png':
                data = tf.image.decode_png(data, dtype=self.image_dtype, name='image_decoder')
            elif self.data_mime == 'image/jpeg':
                data = tf.image.decode_jpeg(data, name='image_decoder')
            else:
                logging.error('Unsupported mime type (%s); cannot be decoded', self.data_mime)
                # sys.exit rather than the ``exit`` builtin: the latter is added
                # by the ``site`` module for interactive use and may be missing.
                import sys
                sys.exit(-1)
        else:
            if self.backend == 'lmdb':
                data = tf.decode_raw(data, self.image_dtype, name='raw_decoder')

            # If data is in CHW, set the shape and convert to HWC.
            if self.unencoded_data_format == 'chw':
                data = tf.reshape(data, [shape[0], shape[1], shape[2]])
                data = digits.chw_to_hwc(data)
            else:  # 'hwc'
                data = tf.reshape(data, shape)

            if self.channels == 3 and self.unencoded_channel_scheme == 'bgr':
                data = digits.bgr_to_rgb(data)

        # Convert to float. This is a plain cast, not a normalisation to [0, 1).
        return tf.to_float(data)
Exemplo n.º 2
0
def train(train_loader, model, criterion, optimizer, epoch, result_dir):
    """Train ``model`` for one epoch and optionally dump visual snapshots.

    For every batch the loss is back-propagated and the optimizer is stepped.
    A progress line is then printed with the epoch, batch index, the last
    param group's learning rate, the last/average loss and the batch wall time.

    On epochs that are a multiple of the module-level ``args.save_freq``,
    every batch also writes ``<result_dir>/<epoch:04d>/train_<ind>.jpg``. That
    image shows the first sample's clean image, noisy input and model output
    side by side.

    Args:
        train_loader: iterable yielding ``(noise_img, origin_img)`` tensor
            pairs. ``.numpy()`` is called on them directly, so they are
            expected to live on the CPU.
        model: the network; it receives ``noise_img.cuda()``.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: torch optimizer; the last param group's ``lr`` is logged.
        epoch: current epoch number, used for logging and the output folder.
        result_dir: root directory for snapshot images.
    """
    losses = AverageMeter()
    model.train()

    # Loop-invariant: whether this epoch writes snapshots, and where.
    save_snapshots = epoch % args.save_freq == 0
    epoch_dir = os.path.join(result_dir, '%04d' % epoch)

    for ind, (noise_img, origin_img) in enumerate(train_loader):
        st = time.time()

        input_var = noise_img.cuda()
        target_var = origin_img.cuda()

        output = model(input_var)
        loss = criterion(output, target_var)

        losses.update(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('[{0}][{1}]\t'
              'lr: {lr:.5f}\t'
              'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
              'Time: {time:.3f}'.format(epoch,
                                        ind,
                                        lr=optimizer.param_groups[-1]['lr'],
                                        loss=losses,
                                        time=time.time() - st))

        if save_snapshots:
            # exist_ok avoids the isdir()/makedirs() check-then-act race.
            os.makedirs(epoch_dir, exist_ok=True)
            _save_train_snapshot(epoch_dir, ind, origin_img, noise_img, output)


def _save_train_snapshot(epoch_dir, ind, origin_img, noise_img, output):
    """Write the first sample's clean | noisy | output panels as one JPEG."""
    origin_np_img = chw_to_hwc(origin_img.numpy()[0])
    noise_np_img = chw_to_hwc(noise_img.numpy()[0])
    output_np_img = chw_to_hwc(output.detach().cpu().numpy()[0])

    strip = np.concatenate((origin_np_img, noise_np_img, output_np_img), axis=1)
    # Clip every panel, not only the model output: np.uint8() of a float
    # outside [0, 255] does not saturate. Round rather than truncate.
    strip = np.uint8(np.round(np.clip(strip, 0, 1) * 255.))
    io.imsave(os.path.join(epoch_dir, 'train_%d.jpg' % ind), strip)
Exemplo n.º 3
0
    def inferenceCMP(self, names, scale=4096):
        """Run ``self.MODEL`` on a stack of image files and return HWC output.

        Each file is read unchanged with ``cv2.imread(name, -1)``, divided by
        ``scale`` and cast to float32. The images are stacked along a new
        leading axis (one slice per file). A batch axis of size 1 is added and
        the tensor is moved to the GPU.

        Args:
            names: image file paths. All must decode to arrays of the same
                shape. For 2-D (single-channel) files each file becomes one
                input channel.
            scale: divisor applied to raw pixel values. It defaults to 4096, the
                original hard-coded value (presumably 12-bit data; confirm).

        Returns:
            ``chw_to_hwc`` applied to the first (only) batch entry of the model
            output, as a NumPy array.

        Raises:
            IOError: if ``cv2.imread`` cannot read one of the files.
        """
        imgs = []
        for name in names:
            img = cv2.imread(name, -1)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise IOError('cannot read image: %s' % name)
            imgs.append((img / scale).astype('float32'))

        # For 2-D inputs: (N, H, W) -> (1, N, H, W).
        batch = torch.from_numpy(np.stack(imgs, axis=0)).unsqueeze(0).cuda()
        # Inference only: skip autograd graph construction.
        with torch.no_grad():
            output = self.MODEL(batch)
        acm = output[0].detach().cpu().numpy()
        return chw_to_hwc(acm)
Exemplo n.º 4
0
# Command-line denoising: load the saved checkpoint, run the network on one
# image and write the result.
import sys

parser.add_argument('input_filename', type=str)
parser.add_argument('output_filename', type=str)
args = parser.parse_args()

save_dir = './save_model/'
checkpoint_path = os.path.join(save_dir, 'checkpoint.pth.tar')

model = Network()
model.cuda()
model = nn.DataParallel(model)

model.eval()

if not os.path.exists(checkpoint_path):
    print('Error: no trained model detected!', file=sys.stderr)
    sys.exit(1)

# Load the existing model. NOTE: torch.load unpickles the file, so only load
# trusted checkpoints.
model_info = torch.load(checkpoint_path)
model.load_state_dict(model_info['state_dict'])

input_image = read_img(args.input_filename)
input_var = torch.from_numpy(hwc_to_chw(input_image)).unsqueeze(0).cuda()

with torch.no_grad():
    # The network returns a pair; only the second element is written out.
    _, output = model(input_var)

output_image = chw_to_hwc(output[0, ...].cpu().numpy())
# Map [0, 1] floats to 8-bit and reverse the channel axis for cv2.imwrite,
# which expects BGR. This presumes read_img yields RGB; confirm.
output_image = np.uint8(np.round(np.clip(output_image, 0, 1) *
                                 255.))[:, :, ::-1]

# cv2.imwrite reports some failures (e.g. an unwritable path) only via its
# return value.
if not cv2.imwrite(args.output_filename, output_image):
    print('Error: could not write %s' % args.output_filename, file=sys.stderr)
    sys.exit(1)