def train(net, epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        if args.fixed:
            net = utils.quantize(net, args.pprec)

        optimizer.step()

        train_loss += loss.data.item()
        _, predicted = torch.max(outputs.data, 1)
        total += float(targets.size(0))
        correct += predicted.eq(targets.data).cpu().sum().type(
            torch.FloatTensor)

        progress_bar(
            batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))

    if args.fixed:
        net = utils.quantize(net, args.pprec)
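The train() loop above relies on utils.quantize(net, args.pprec), which is not shown in these examples. A minimal sketch of what such a fixed-point weight quantizer could look like, assuming pprec is the number of fractional bits:

import torch

def quantize_net_sketch(net, pprec):
    # Hypothetical stand-in for utils.quantize(net, pprec): round every
    # parameter to the nearest multiple of 2**-pprec (a fixed-point grid).
    step = 2.0 ** (-pprec)
    with torch.no_grad():
        for param in net.parameters():
            param.copy_(torch.round(param / step) * step)
    return net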
Example #2
    def _generateSimulationPayload(self):
        """
            block payload is as follows

            block:
            {
                position: (x, y), - 14 bit integers
                angle: r,         - 11 bit integer
                blockIndex: n     - 9  bit integer  (max 512 entities)
            }

            total: 28 + 11 + 9 = 48 bits per block
            over ~200 blocks is 10000 bits per frame = 10 kb
            @ 60fps = 600 kb/s

            most blocks probably won't be moving (worst case maybe 100 at a time; mobs + block interactions)
            this brings it down to roughly 300 kb/s on average.
            
            with compression this should be much better, ~100 kb/s

            for this server test, there are only max like 10 blocks, so we keep it down to 500 bits per frame,
            or about 30 kb/s
        
        
            each block is 48 bits in total, so it's 6 bytes
        """

        POS_SIZE = 14
        ANGLE_SIZE = 11
        IDX_SIZE = 9

        payload = 0
        for idx, b in enumerate(self.blocks):
            acc = 0

            pos = b.body.position
            angle = b.body.angle

            # quantize to 14 bits signed integer
            # 14 bits gives 2**14 - 1 = 16383 levels
            # field is 2000 wide, gives precision of about 0.12, which is totally fine for what we're doing.
            prec = 2000 / (2**POS_SIZE - 1)
            invPrec = 1 / prec
            x, y = utils.quantize(pos.x, -1000, 1000, POS_SIZE), utils.quantize(
                pos.y, -1000, 1000, POS_SIZE)

            acc = x + (y << POS_SIZE)

            # quantize angle to 11 bit integer
            # convert angle to positive number and remove redundant angle
            angle %= np.pi * 2
            r = utils.quantize(angle, 0, np.pi * 2, ANGLE_SIZE)

            acc += r << (POS_SIZE + POS_SIZE)
            acc += idx << (POS_SIZE + POS_SIZE + ANGLE_SIZE)

            payload += acc << (idx *
                               (POS_SIZE + POS_SIZE + ANGLE_SIZE + IDX_SIZE))
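Here utils.quantize(value, -1000, 1000, POS_SIZE) maps a float with a known range onto an unsigned n-bit integer; the helper itself is not part of the snippet. A hedged sketch of that mapping and its inverse:

def quantize_range_sketch(value, lo, hi, bits):
    # Hypothetical range quantizer: clamp value to [lo, hi], then map it
    # linearly onto the integers 0 .. 2**bits - 1.
    levels = (1 << bits) - 1
    value = min(max(value, lo), hi)
    return int(round((value - lo) / (hi - lo) * levels))

def dequantize_range_sketch(q, lo, hi, bits):
    # Inverse mapping a receiver could use to recover an approximate float.
    levels = (1 << bits) - 1
    return lo + (q / levels) * (hi - lo)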
Example #3
File: train.py  Project: zxlation/FC2N
def valid_one_scale(model, dataset, scale, global_step):
    num_images = len(dataset)
    model_psnr = 0.0
    cubic_psnr = 0.0
    total_loss = 0.0
    print("        [ ====== SR X%d ====== ]" % scale)
    for idx_image in range(num_images):
        print('  -- image %d/%d...' % (idx_image + 1, num_images))
                
        # generate batch, batch_size = 1 when validation
        imLR, imGT, im_name = dataset[idx_image]
        
        # run model
        inp_batch = imLR[np.newaxis, ...].astype(np.float32)
        imSR = model.chop_forward(inp_batch, scale)
        imSR = quantize(imSR)
        
        # use this to approximate bicubic interpolation
        cuSR = resize(imLR, imGT.shape, order = 3, mode = "symmetric", preserve_range = True)
        cuSR = quantize(cuSR)
        
        # calc model loss
        model_loss = np.mean(np.abs(np.float32(imGT) - np.float32(imSR)))
        total_loss += model_loss
        
        # model and bicubic PSNR
        mo_psnr = calc_test_psnr(imGT, imSR, scale)
        cu_psnr = calc_test_psnr(imGT, cuSR, scale)
                
        model_psnr += mo_psnr
        cubic_psnr += cu_psnr
        
    # calculate the average PSNR over the whole validation set
    total_loss = total_loss / num_images
    model_psnr = model_psnr / num_images
    cubic_psnr = cubic_psnr / num_images
    
    # training logs
    target_dir = os.path.join(record_dir, model.name, "X%d" % scale)
    if not os.path.exists(target_dir): os.makedirs(target_dir)
    tarname = os.path.join(target_dir, "valid_record.txt")
    with open(tarname, "a") as file:
        format_str = "%d\t%.4f\t%.2f\t%.2f\n"
        file.write(format_str % (int(global_step), total_loss, model_psnr, cubic_psnr))
                    
    model_gain = model_psnr - cubic_psnr
    color_psnr = "grey" if model_gain < 0 else "red"
    
    formatstr = "  model_psnr = %.2f\tcubic_psnr = %.2f" % (model_psnr, cubic_psnr)
    print(colored(formatstr, 'white', attrs = ['bold']))
    formatstr = "  total_loss = %.2f\tmodel_gain = %.2f" % (total_loss, model_gain)
    print(colored(formatstr, color_psnr, attrs = ["bold"]))
    
    return model_psnr
Example #4
def main():
    ## data
    print('Loading data...')
    test_hr_path = os.path.join('data/', dataset)
    if dataset == 'Set5':
        ext = '*.bmp'
    else:
        ext = '*.png'
    hr_paths = sorted(glob.glob(os.path.join(test_hr_path, ext)))

    ## model
    print('Loading model...')
    tensor_lr = tf.placeholder('float32', [1, None, None, 3], name='tensor_lr')
    tensor_b = tf.placeholder('float32', [1, None, None, 3], name='tensor_b')

    tensor_sr = IDN(tensor_lr, tensor_b, scale)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, model_path)

    ## result
    save_path = os.path.join(saved_path, dataset + '/x' + str(scale))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    psnr_score = 0
    for i, _ in enumerate(hr_paths):
        print('processing image %d' % (i + 1))
        img_hr = utils.modcrop(misc.imread(hr_paths[i]), scale)
        img_lr = utils.downsample_fn(img_hr, scale=scale)
        img_b = utils.upsample_fn(img_lr, scale=scale)
        [lr, b] = utils.datatype([img_lr, img_b])
        lr = lr[np.newaxis, :, :, :]
        b = b[np.newaxis, :, :, :]
        [sr] = sess.run([tensor_sr], {tensor_lr: lr, tensor_b: b})
        sr = utils.quantize(np.squeeze(sr))
        img_sr = utils.shave(sr, scale)
        img_hr = utils.shave(img_hr, scale)
        if not rgb:
            img_pre = utils.quantize(sc.rgb2ycbcr(img_sr)[:, :, 0])
            img_label = utils.quantize(sc.rgb2ycbcr(img_hr)[:, :, 0])
        else:
            img_pre = img_sr
            img_label = img_hr
        psnr_score += utils.compute_psnr(img_pre, img_label)
        misc.imsave(os.path.join(save_path, os.path.basename(hr_paths[i])), sr)

    print('Average PSNR: %.4f' % (psnr_score / len(hr_paths)))
    print('Finish')
Example #5
def updategui():
    global psnrs
    global out
    gt2 = gt  # keep a local reference to the global ground-truth tensor
    apptr.PSNR.setText(f"psnr: {round(psnrs.avg)}")
    apptr.ITER.setText(f"iterations: {it}")
    q_im = utils.quantize(out[0].data.mul(255))
    q_im = q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    q_im = cv2.resize(q_im, (256,256), interpolation = cv2.INTER_AREA)
    apptr.label_2.setPixmap( QPixmap(QImage(q_im.data, 256, 256, 768, QImage.Format_RGB888))  )
    q_im = utils.quantize(gt2[0].data.mul(255))
    q_im = q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    q_im = cv2.resize(q_im, (256,256), interpolation = cv2.INTER_AREA)
    apptr.label_3.setPixmap( QPixmap(QImage(q_im.data, 256, 256, 768, QImage.Format_RGB888))  )
Example #6
    def forward(self, x):
        x = self.E_Conv_1(x)
        x = self.E_PReLU_1(x)
        x = self.E_Conv_2(x)
        x = self.E_PReLU_2(x)
        x = self.E_Conv_3(x)
        x = self.E_PReLU_3(x)
        x = self.E_Res(x)
        x = self.E_Conv_4(x)
        x = self.E_Conv_5(x)
        x = self.E_Conv_6(x)

        if self.prune:
            x = self.Pruner(x, self.threshold)
        x = quantize(x)
        #print(self.D_SubPix_00)
        y = self.D_SubPix_00(x)
        y = self.D_SubPix_0(y)
        y = self.D_SubPix_1(y)
        y = self.D_PReLU_1(y)
        y = self.D_Res(y)
        y = self.D_SubPix_2(y)
        y = self.D_PReLU_2(y)
        y = self.D_SubPix_3(y)
        y = self.D_PReLU_3(y)
        y = self.D_SubPix_4(y)
        y = (self.tanh(y) + 1) / 2

        return y, x
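The quantize(x) call above sits between encoder and decoder, so gradients still have to flow through it during training. The actual implementation is not shown; a common choice for such in-network quantization is rounding with a straight-through gradient, sketched here as an assumption:

import torch

class _RoundSTE(torch.autograd.Function):
    # Round in the forward pass, pass gradients through unchanged in backward.
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def quantize_latent_sketch(x):
    # Hypothetical stand-in for the quantize() used in forward() above.
    return _RoundSTE.apply(x)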
Example #7
    def do_pca(self, input, max_quantized_value=2.0, min_quantized_value=-2.0):
        reduce_dim = 1024
        load_file = open("model_pca_tag_category_100w.pickle", "rb")
        mean_block3 = pickle.load(load_file)
        component_block3 = pickle.load(load_file)
        component_block3 = component_block3[:, 0:reduce_dim]
        singular_values_ = pickle.load(load_file)
        singular_block3 = tf.constant(singular_values_,
                                      dtype=tf.float32,
                                      name='pac_singular_block3')
        mean_block3 = tf.constant(mean_block3,
                                  dtype=tf.float32,
                                  name='pac_mean_block3')
        component_block3 = tf.constant(component_block3,
                                       dtype=tf.float32,
                                       name='pac_component_block3')
        res_fea_pca = tf.matmul(
            input - mean_block3,
            component_block3) / tf.sqrt(singular_block3[0:reduce_dim] + 1e-4)

        res_fea = utils.quantize(res_fea_pca,
                                 max_quantized_value=max_quantized_value,
                                 min_quantized_value=min_quantized_value)
        res_fea = utils.Dequantize(res_fea,
                                   max_quantized_value=max_quantized_value,
                                   min_quantized_value=min_quantized_value)
        # res_fea_pca = tf.reshape(res_fea_pca, [-1, frams, reduce_dim])
        # res_fea = tf.reshape(res_fea_pca, tf.shape(res_fea_pca))
        return res_fea
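do_pca() above passes the PCA-reduced features through utils.quantize and utils.Dequantize to simulate 8-bit storage within [min_quantized_value, max_quantized_value]. A sketch of that round trip in the YouTube-8M style (an assumption; the utils module is not included here):

import tensorflow as tf

def quantize_features_sketch(features, max_quantized_value=2.0, min_quantized_value=-2.0):
    # Hypothetical: clip to the value range and map it onto 256 integer levels.
    q_range = max_quantized_value - min_quantized_value
    features = tf.clip_by_value(features, min_quantized_value, max_quantized_value)
    return tf.round((features - min_quantized_value) / q_range * 255.0)

def dequantize_features_sketch(quantized, max_quantized_value=2.0, min_quantized_value=-2.0):
    # Inverse mapping back to approximate float features.
    q_range = max_quantized_value - min_quantized_value
    return quantized / 255.0 * q_range + min_quantized_value
Example #8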
def retrain(net, epoch, mask_prune, lr):
    print('\nEpoch: %d' % epoch)
    global best_acc

    net.train()
    train_loss = 0
    total = 0
    correct = 0

    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        net = utils.pruneNetwork(net, mask_prune)

        if args.fixed:
            net = utils.quantize(net, args.pprec)

        optimizer.step()

        train_loss += loss.data.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += float(predicted.eq(targets.data).cpu().sum())

        progress_bar(
            batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))
        acc = 100. * correct / total

    net = utils.pruneNetwork(net, mask_prune)

    if args.fixed:
        net = utils.quantize(net, args.pprec)
Example #9
    def run_inference(self, image):
        """
        Performs the inferencing and OCR

        :param image: An image(numpy array) to perform inferencing on
        :return: List of text detected; returns an empty list if no object is detected
        """
        input_shape = self.get_input_shape()
        preprocess_function = lambda img: quantize((img / 127.5) - 1, 128, 127)
        out_list = self.invoke(image, preprocess_function)

        # model outputs 3 tensors now, normalized to between 0 and 1
        score_map = dequantize(out_list[0], 128, 127)
        geo_loc_map = dequantize(out_list[1], 128, 127)
        geo_angle = dequantize(out_list[2], 128, 127)
        score_map = (score_map + 1) * 0.5
        geo_loc_map = (geo_loc_map + 1) * 256
        geo_angle = 0.7853981633974483 * geo_angle
        geo_map = np.concatenate((geo_loc_map, geo_angle), axis=3)

        boxes = text_detection(score_map=score_map, geo_map=geo_map)

        if boxes is not None:
            boxes = boxes[:, :8].reshape((-1, 4, 2))
            boxes[:, :, 0] /= input_shape[1] / image.shape[0]
            boxes[:, :, 1] /= input_shape[2] / image.shape[1]

            output_text = []
            for box in boxes:
                box = sort_poly(box.astype(np.int32))
                if np.linalg.norm(box[0] -
                                  box[1]) < 5 or np.linalg.norm(box[3] -
                                                                box[0]) < 5:
                    continue
                (x_max, x_min), (y_max,
                                 y_min) = xy_maxmin(box[:, 0], box[:, 1])

                if x_max > image.shape[0]:
                    x_max = image.shape[0]
                if x_min < 0:
                    x_min = 0
                if y_max > image.shape[1]:
                    y_max = image.shape[1]
                if y_min < 0:
                    y_min = 0

                cv2.polylines(image, [box.astype(np.int32).reshape(-1, 1, 2)],
                              True,
                              color=(255, 255, 0),
                              thickness=2)

                sub_img = image[y_min:y_max, x_min:x_max]
                txt = get_text(sub_img)
                if txt != '':
                    output_text.append(txt)
            return output_text
        return []
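run_inference() above quantizes the preprocessed image with quantize((img / 127.5) - 1, 128, 127) and dequantizes the three output tensors with dequantize(..., 128, 127). Those helpers are not shown; one plausible reading, assuming the two numbers act as an inverse scale and a zero point for uint8 tensors, is:

import numpy as np

def quantize_uint8_sketch(x, inv_scale=128, zero_point=127):
    # Hypothetical: map floats in roughly [-1, 1] to uint8 via
    # q = round(x * inv_scale + zero_point), clipped to [0, 255].
    return np.clip(np.round(x * inv_scale + zero_point), 0, 255).astype(np.uint8)

def dequantize_uint8_sketch(q, inv_scale=128, zero_point=127):
    # Inverse: recover floats as (q - zero_point) / inv_scale.
    return (q.astype(np.float32) - zero_point) / inv_scale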
Example #10
    def eval(self):
        with torch.no_grad():
            output = self.model(self.eval_input)
            output = quantize(output, self.opt.rgb_range)

        # output = torch.clamp(output, 0, 1)

        return {'input': self.eval_input[0], 'output': output[0],
                'target': self.eval_target[0]}
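In this eval(), quantize(output, self.opt.rgb_range) snaps the network output back onto valid pixel values. A sketch in the style common to EDSR-like super-resolution code (assumed here, since the project's own quantize is not shown):

import torch

def quantize_pixels_sketch(img, rgb_range):
    # Hypothetical: round to the 0..255 pixel grid expressed in the model's
    # rgb_range (e.g. 1.0 or 255) and clamp to the valid range.
    pixel_range = 255.0 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)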
Example #11
    def train(self):
        self.output = self.model(self.input, self.input_bicu)

        loss = self.criterion(self.output, self.target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.output = quantize(self.output, self.opt.rgb_range)
        return loss
Example #12
    def train(self):
        x_out = F.interpolate(self.input,
                              scale_factor=self.sr_factor,
                              mode='bicubic')
        weights = np.linspace(0, 1, self.opt.recur_step)
        weight2 = 0.1 + (self.epoch // 20) * 0.2
        for i in range(self.opt.recur_step):
            SR_out, deblock_out = self.model(self.input, x_out.detach())
            loss = (self.criterion(deblock_out, self.target_down) *
                    (1 - weight2) +
                    self.criterion(SR_out, self.target) * weight2) * weights[i]
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            x_out = SR_out

        self.SR_out = quantize(SR_out, self.opt.rgb_range)
        self.deblock_out = quantize(deblock_out, self.opt.rgb_range)

        return loss
Example #13
def main():
    log_dir = path.join('sample', 'log')
    shutil.rmtree(log_dir, ignore_errors=True)
    writer = tensorboard.SummaryWriter(log_dir=log_dir)

    img_input = imageio.imread(path.join('sample', 'butterfly_lr.png'))
    img_target = imageio.imread(path.join('sample', 'butterfly.png'))

    img_input, img_target = pp.to_tensor(img_input, img_target)
    writer.add_image('img_input', utils.quantize(img_input))
    writer.add_image('img_target', utils.quantize(img_target))

    dir_input = '[your_path]'
    dir_target = '[your_path]'
    data_test = backbone.RestorationData(dir_input, dir_target, method='direct')
    x, y = data_test[0]
    writer.add_image('patch_input', utils.quantize(x)) 
    writer.add_image('patch_target', utils.quantize(y)) 
    # Bug?
    writer.add_image('patch_target', utils.quantize(y)) 
Example #14
def test(net):
    global glob_gau
    global glob_blur
    global best_acc
    glob_blur = 1
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    mask_channel = torch.load('mask_null.dat')
    mask_channel = utils.setMask(utils.setMask(mask_channel, 3, 1), 4, 0)
    if args.mode > 0:
        net = utils.netMaskMul(net, mask_channel)
        net = utils.addNetwork(net, net2)
    if args.fixed == 1:
        net = utils.quantize(net, args.pprec)
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        test_loss += loss.data.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += float(predicted.eq(targets.data).cpu().sum())

        progress_bar(
            batch_idx, len(test_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (test_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))

    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:

        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        if args.mode == 0:
            pass
        else:
            print('Saving..')
            torch.save(state, './checkpoint/ckpt_20190802_half_clean_B1.t0')
        best_acc = acc

    return acc
Example #15
File: train.py  Project: v4lkyri3/PPON
def test():
    avg_psnr = 0

    for batch in testing_data_loader:
        input, target = batch[0].detach(), batch[1].detach()
        model.feed_data([input], need_HR=False)
        model.test()
        pre = model.get_current_visuals(need_HR=False)
        sr_img = utils.tensor2np(pre['SR'].data)
        gt_img = utils.tensor2np(target.data[0])
        crop_size = args.scale
        cropped_sr_img = utils.shave(sr_img, crop_size)
        cropped_gt_img = utils.shave(gt_img, crop_size)
        if is_y is True:
            im_label = utils.quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0])
            im_pre = utils.quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0])
        else:
            im_label = cropped_gt_img
            im_pre = cropped_sr_img
        avg_psnr += utils.compute_psnr(im_pre, im_label)

    print("===> Valid. psnr: {:.4f}".format(avg_psnr /
                                            len(testing_data_loader)))
Example #16
    def eval(self):
        SR_out_list = []
        deblock_out_list = []
        with torch.no_grad():
            x_out = F.interpolate(self.eval_input,
                                  scale_factor=self.sr_factor,
                                  mode='bicubic')
            for i in range(self.opt.recur_step):
                SR_out, deblock_out = self.model(self.eval_input,
                                                 x_out.detach())
                x_out = SR_out
                SR_out_list.append(SR_out)
                deblock_out_list.append(deblock_out)

            SR_out_ = quantize(SR_out_list[-1], self.opt.rgb_range)
            deblock_out_ = quantize(deblock_out_list[-1], self.opt.rgb_range)

        images = {
            'input': self.eval_input[0],
            'target': self.eval_target[0],
            'output': SR_out_[0],
            'deblock_output': deblock_out_[0]
        }
        return images
Example #17
    def prototype(self):
        model = load_model(self._model_path)
        embedding_size = int(model.output.shape[-1])
        embeddings = []
        anchors = set()
        axesi = 0
        patches = []
        for image_annotation in self._image_annotations:
            support_image = WholeSlideImageASAP(image_annotation.image_path)
            for annotation in image_annotation.annotations:
                ratio = support_image.get_downsampling_from_spacing(
                    self._spacing)
                _, _, width, height = (annotation.bounds[0],
                                       annotation.bounds[1],
                                       annotation.bounds[2] - annotation.bounds[0],
                                       annotation.bounds[3] - annotation.bounds[1])
                width, height = quantize(width // ratio, height // ratio,
                                         self._grid_cell_size)
                x, y = annotation.center
                anchors.add((width // 64, height // 64))
                support_patch = support_image.get_patch(
                    x, y, width, height, 0.5)
                patches.append(support_patch)
                blocks = skimage.util.view_as_blocks(
                    support_patch,
                    (64, 64, 3)).squeeze().reshape(-1, 64, 64, 3)
                # apply data-augmentations
                blocks_augmented1 = augmentor(blocks)
                blocks_augmented2 = augmentor(blocks)
                all_blocks = np.concatenate(
                    [blocks, blocks_augmented1, blocks_augmented2])
                embeddings.append(
                    model.predict_on_batch(
                        all_blocks / 255.0).squeeze().reshape(
                            -1, embedding_size).mean(axis=(0)))
            support_image.close()
            support_image = None
            del support_image
        proto = np.array(embeddings).mean(axis=0)
        # find threshold
        thresholds = []
        cos_thresholds = []
        for findex in range(len(embeddings)):
            # cos_thresholds = dot(embeddings[findex], embeddings[sindex])/(norm(embeddings[findex])*norm(embeddings[sindex]))
            thresholds.append(np.linalg.norm(embeddings[findex] - proto))

        thresholds = [t for t in thresholds if t != 0]
        return np.array(embeddings).mean(axis=0), self._calculate_threshold(
            thresholds), anchors, patches, thresholds
Example #18
def find_outer_radius(img_arr, center_point):
    cont = np.copy(img_arr)
    cont = contrast(cont, 2)
    grayscale, _ = conv_gray(cont)

    grayscale, palette = quantize(grayscale, 4)

    # threshold all values smaller than maximal
    grayscale[grayscale < palette[-2] + 2] = 0
    grayscale[grayscale > 0] = 255

    grayscale = erode(grayscale, 7)
    grayscale = dilate(grayscale, 7)

    R = estimate_radius(grayscale, center_point)
    return R
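find_outer_radius() relies on quantize(grayscale, 4) returning both the reduced image and a palette whose second-brightest level (palette[-2]) is used for thresholding. A hedged sketch of such a level-quantization helper:

import numpy as np

def quantize_levels_sketch(gray, n_levels):
    # Hypothetical: bin intensities into n_levels evenly spaced levels and
    # return the quantized image together with the sorted palette values.
    gray = gray.astype(np.float32)
    lo, hi = gray.min(), gray.max()
    edges = np.linspace(lo, hi, n_levels + 1)
    palette = (edges[:-1] + edges[1:]) / 2.0  # one representative value per bin
    idx = np.clip(np.digitize(gray, edges[1:-1]), 0, n_levels - 1)
    return palette[idx], palette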
Example #19
    def validation_step(self, batch, batch_idx):
        x, y = batch

        x = quantize(x, self.centroids)
        x = _to_sequence(x)

        if self.classify:
            clf_logits = self.gpt(x, classify=True)
            loss = self.criterion(clf_logits, y)
            _, preds = torch.max(clf_logits, 1)
            correct = preds == y
            return {"val_loss": loss, "correct": correct}
        else:
            logits = self.gpt(x)
            loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1))
            return {"val_loss": loss}
Example #20
    def training_step(self, batch, batch_idx):
        x, y = batch

        x = quantize(x, self.centroids)
        x = _to_sequence(x)

        if self.classify:
            # TODO: joint loss
            clf_logits = self.gpt(x, classify=True)
            loss = self.criterion(clf_logits, y)
        else:
            logits = self.gpt(x)
            loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1))

        logs = {"loss": loss}
        return {"loss": loss, "log": logs}
Example #21
    def train(self):
        # 3x3 blur kernel, replicated per channel for a depthwise convolution
        gaussian = torch.FloatTensor([1, 2, 1, 2, 4, 5, 1, 2,
                                      1]).view(1, 1, 3, 3)
        gaussian = torch.cat([gaussian, gaussian, gaussian],
                             dim=0).to(self.target_down.device)

        blur_target = F.conv2d(self.target_down,
                               gaussian,
                               stride=1,
                               padding=1,
                               groups=3)
        self.output = self.model(self.input, blur_target)

        loss = self.criterion(self.output, self.target_down)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.output = quantize(self.output, self.opt.rgb_range)

        return loss
Example #22
    def eval(self):
        with torch.no_grad():
            # 3x3 blur kernel, replicated per channel for a depthwise convolution
            gaussian = torch.FloatTensor([1, 2, 1, 2, 4, 5, 1, 2,
                                          1]).view(1, 1, 3, 3)
            gaussian = torch.cat([gaussian, gaussian, gaussian],
                                 dim=0).to(self.target_down.device)

            blur_target = F.conv2d(self.eval_target,
                                   gaussian,
                                   stride=1,
                                   padding=1,
                                   groups=3)
            output = self.model(self.eval_input, blur_target)
            output = quantize(output, self.opt.rgb_range)

        return {
            'input': self.eval_input[0],
            'output': output[0],
            'target': self.eval_target[0]
        }
Example #23
def main(args):
    model = ImageGPT.load_from_checkpoint(args.checkpoint).gpt.eval().cuda()

    centroids = np.load(args.centroids)
    train_dl, valid_dl, test_dl = dataloaders(args.dataset, 1)

    dl = iter(DataLoader(valid_dl.dataset, shuffle=True))

    # rows for figure
    rows = []

    for example in tqdm(range(args.num_examples)):
        img, _ = next(dl)
        h, w = img.shape[-2:]

        img = quantize(img, torch.from_numpy(centroids)).numpy()[0]
        seq = img.reshape(-1)

        # first half of image is context
        context = seq[: int(len(seq) / 2)]
        context_img = np.pad(context, (0, int(len(seq) / 2))).reshape(h, w)
        context = torch.from_numpy(context).cuda()

        # predict second half of image
        preds = (
            sample(model, context, int(len(seq) / 2), num_samples=args.num_samples)
            .cpu()
            .numpy()
            .transpose()
        )

        preds = preds.reshape(-1, h, w)

        # combine context, preds, and truth for figure
        rows.append(
            np.concatenate([context_img[None, ...], preds, img[None, ...]], axis=0)
        )

    figure = make_figure(rows, centroids)
    figure.save("figure.png")
def retrain(net, epoch, mask):
    print('\nEpoch: %d' % epoch)
    global best_acc
    net.train()
    train_loss = 0
    total = 0
    correct = 0
    mask_channel = torch.load('mask_null.dat')
    mask_channel = utils.setMask(utils.setMask(mask_channel, 0, 1), 2, 0)
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        if args.fixed:
            net = utils.quantize(net, args.pprec)

        net = utils.netMaskMul(net, mask_channel)

        net = utils.pruneNetwork(net, mask)

        optimizer.step()

        train_loss += loss.data.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += float(predicted.eq(targets.data).cpu().sum())

        progress_bar(
            batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))
        acc = 100. * correct / total
Example #25
    def eval(self):
        with torch.no_grad():
            output = self.model(self.eval_input_y, self.eval_input_bicu_y)
            output = quantize(output, self.opt.rgb_range)

        self.output = self.eval_input_bicu.data.clone()
        self.output[:, 0, :, :] = output
        self.output = self.output[0].cpu().permute(1, 2, 0).numpy()
        self.output = ycbcr2rgb(self.output)
        self.output = torch.from_numpy(self.output).permute(2, 0, 1)

        self.eval_input = self.eval_input[0].cpu().permute(1, 2, 0).numpy()
        self.eval_input = ycbcr2rgb(self.eval_input)
        self.eval_input = torch.from_numpy(self.eval_input).permute(2, 0, 1)

        self.eval_target = self.eval_target[0].cpu().permute(1, 2, 0).numpy()
        self.eval_target = ycbcr2rgb(self.eval_target)
        self.eval_target = torch.from_numpy(self.eval_target).permute(2, 0, 1)

        return {
            'input': self.eval_input,
            'output': self.output,
            'target': self.eval_target
        }
Example #26
def generate_quantized_maps_from_dict(maps):
    """This function quantizes the maps
    
    Arguments:
        maps {dict} -- keys in ('train', 'test'), values are lists of tuples
            ('file_name', 'numpy_map', 'split_id', 'genre')

    Returns:
        quantized_maps {dict} -- keys in ('train', 'test'), values are lists 
        of tuples ('file_name', 'quantized_map', 'split_id', 'genre')
    """

    params = load_params()

    quantized_maps = {"train": None, "test": None}

    for split in quantized_maps:
        quantized_maps[split] = [
            (piece[0],
             quantize(piece[1], n_levels=params["quantization"]["n_levels"]),
             piece[2], piece[3]) for piece in maps[split]
        ]

    return quantized_maps
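quantize(piece[1], n_levels=...) above discretizes each continuous-valued map into a fixed number of levels. One plausible implementation, assuming uniform binning over each map's own value range:

import numpy as np

def quantize_map_sketch(arr, n_levels):
    # Hypothetical: normalize the array to [0, 1], split it into n_levels
    # uniform bins, and return the integer level index of every entry.
    arr = np.asarray(arr, dtype=np.float32)
    lo, hi = arr.min(), arr.max()
    if hi == lo:
        return np.zeros(arr.shape, dtype=np.int64)
    norm = (arr - lo) / (hi - lo)
    return np.minimum((norm * n_levels).astype(np.int64), n_levels - 1)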
Example #27
def pie(xts, yts, ytothers, bounds=None, subtitle=None, title=None):

    if bounds==None:
        colors = colorwheel(5)
    else:
        colors = colorwheel(len(bounds))

    figure = plot.figure()
    for index,(xt,yt,yother) in enumerate(zip(xts,yts,ytothers)):

        # sites OPEN in ESC and OPEN in atleast one other cell
        subplot = figure.add_subplot(4,4,4*index+1, aspect='equal')
        [spine.set_linewidth(0.1) for spine in subplot.spines.values()]
        if bounds==None:
            quantized = utils.quantize(xt[yt*yother==1], q=5)
            proportions = [q.size for q in quantized]
        else:
            proportions = [((xt>=bound[0])*(xt<bound[1])*(yt*yother==1)).sum() for bound in bounds]
        patches, texts = subplot.pie(proportions, labels=map(str,proportions), colors=colors, labeldistance=1.2)
        for text in texts:
            text.set_fontsize(8)

        if subtitle:
            bbox = subplot.get_position()
            xloc = bbox.xmin/2.
            yloc = (bbox.ymax+bbox.ymin)/2.
            plot.text(xloc, yloc, subtitle[index], fontsize=8, horizontalalignment='center', \
                verticalalignment='center', transform=figure.transFigure)

        if index==0:
            bbox = subplot.get_position()
            xloc = (bbox.xmax+bbox.xmin)/2.
            yloc = (3*bbox.ymax+1)/4.
            plot.text(xloc, yloc, 'ESC && OTHER', fontsize=8, horizontalalignment='center', \
                verticalalignment='bottom', transform=figure.transFigure)

        # sites OPEN in ESC and CLOSED in all other cells
        subplot = figure.add_subplot(4,4,4*index+2, aspect='equal')
        [spine.set_linewidth(0.1) for spine in subplot.spines.values()]
        if bounds==None:
            quantized = utils.quantize(xt[yt*(1-yother)==1], q=5)
            proportions = [q.size for q in quantized]
        else:
            proportions = [((xt>=bound[0])*(xt<bound[1])*(yt*(1-yother)==1)).sum() for bound in bounds]
        patches, texts = subplot.pie(proportions, labels=map(str,proportions), colors=colors, labeldistance=1.2)
        for text in texts:
            text.set_fontsize(8)

        if index==0:
            bbox = subplot.get_position()
            xloc = (bbox.xmax+bbox.xmin)/2.
            yloc = (3*bbox.ymax+1)/4.
            plot.text(xloc, yloc, 'ESC && !OTHER', fontsize=8, horizontalalignment='center', \
                verticalalignment='bottom', transform=figure.transFigure)

        # sites CLOSED in ESC and OPEN in atleast one other cell
        subplot = figure.add_subplot(4,4,4*index+3, aspect='equal')
        [spine.set_linewidth(0.1) for spine in subplot.spines.values()]
        if bounds==None:
            quantized = utils.quantize(xt[(1-yt)*yother==1], q=5)
            proportions = [q.size for q in quantized]
        else:
            proportions = [((xt>=bound[0])*(xt<bound[1])*((1-yt)*yother==1)).sum() for bound in bounds]
        patches, texts = subplot.pie(proportions, labels=map(str,proportions), colors=colors, labeldistance=1.2)
        for text in texts:
            text.set_fontsize(8)

        if index==0:
            bbox = subplot.get_position()
            xloc = (bbox.xmax+bbox.xmin)/2.
            yloc = (3*bbox.ymax+1)/4.
            plot.text(xloc, yloc, '!ESC && OTHER', fontsize=8, horizontalalignment='center', \
                verticalalignment='bottom', transform=figure.transFigure)

        # sites CLOSED in ESC and CLOSED in all other cells
        subplot = figure.add_subplot(4,4,4*index+4, aspect='equal')
        [spine.set_linewidth(0.1) for spine in subplot.spines.values()]
        if bounds==None:
            quantized = utils.quantize(xt[(1-yt)*(1-yother)==1], q=5)
            proportions = [q.size for q in quantized]
        else:
            proportions = [((xt>=bound[0])*(xt<bound[1])*((1-yt)*(1-yother)==1)).sum() for bound in bounds]
        patches, texts = subplot.pie(proportions, labels=map(str,proportions), colors=colors, labeldistance=1.2)
        for text in texts:
            text.set_fontsize(8)

        if index==0:
            bbox = subplot.get_position()
            xloc = (bbox.xmax+bbox.xmin)/2.
            yloc = (3*bbox.ymax+1)/4.
            plot.text(xloc, yloc, '!ESC && !OTHER', fontsize=8, horizontalalignment='center', \
                verticalalignment='bottom', transform=figure.transFigure)

    if title:
        plot.suptitle(title)

    return figure
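In pie() above (and plot_reads() below), utils.quantize(xt, q=5) splits the sites into q groups by their xt value and returns one index array per group; the pie charts count q.size per group and the legends read xt[quant].min()/max(). A sketch consistent with that usage, assuming equal-width bins unless explicit bounds are given:

import numpy as np

def quantize_groups_sketch(xt, q=5, bounds=None):
    # Hypothetical: return a list of q index arrays, one per bin of xt values.
    xt = np.asarray(xt)
    if bounds is None:
        edges = np.linspace(xt.min(), xt.max(), q + 1)
        bounds = list(zip(edges[:-1], edges[1:]))
        lo_last, hi_last = bounds[-1]
        bounds[-1] = (lo_last, hi_last + 1e-12)  # make the top bin inclusive
    return [np.where((xt >= lo) & (xt < hi))[0] for lo, hi in bounds]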
Example #28
        if h % 4 == 0 and w % 4 == 0:
            start.record()
            out = model(im_input)
            end.record()
            torch.cuda.synchronize()
            time_list[i] = start.elapsed_time(end)  # milliseconds
        else:
            start.record()
            out = crop_forward(im_input, model)
            end.record()
            torch.cuda.synchronize()
            time_list[i] = start.elapsed_time(end)  # milliseconds

    sr_img = utils.tensor2np(out.detach()[0])
    if opt.is_y is True:
        im_label = utils.quantize(sc.rgb2ycbcr(im_gt)[:, :, 0])
        im_pre = utils.quantize(sc.rgb2ycbcr(sr_img)[:, :, 0])
    else:
        im_label = im_gt
        im_pre = sr_img
    psnr_list[i] = utils.compute_psnr(im_pre, im_label)
    ssim_list[i] = utils.compute_ssim(im_pre, im_label)

    output_folder = os.path.join(opt.output_folder, imname.split('/')[-1])

    if not os.path.exists(opt.output_folder):
        os.makedirs(opt.output_folder)

    sio.imsave(output_folder, sr_img)
    i += 1
Example #29
def plot_reads(Reads, xts, motiflen=None, q=5, quantile=True, bounds=None, subtitle=None, cellnames=None, title=None, hist=False):

    """
    reads is a list of list of binary tuples of lists of numpy array with read counts. The outermost
    list is over different cell types. The next inner list
    is over different TSS dist thresholds. The binary tuple is for same / opposite strands.
    """

    width = Reads[0][0][0][0].size
    figure = plot.figure()
    xvals = np.arange(-width/2,width/2)
    numcols = len(Reads)
    numrows = len(Reads[0])
    if hist:
        numrows = 2*numrows

    colors = colorwheel(q)

    for cellidx, reads in enumerate(Reads): 

        for index, (xt,read) in enumerate(zip(xts,reads)):

            if quantile:
                quantized = utils.quantiles(xt, q=q)
            else:
                quantized = utils.quantize(xt, q=q, bounds=bounds)
            same = [np.mean([read[0][idx] for idx in quant if read[0][idx].size==width],0) for quant in quantized]
            opp = [-1*np.mean([read[1][idx] for idx in quant if read[1][idx].size==width],0) for quant in quantized]
            if hist:
                subplot = figure.add_subplot(numrows,numcols,2*index*numcols+cellidx+1)
            else:
                subplot = figure.add_subplot(numrows,numcols,index*numcols+cellidx+1)
            subplot = remove_spines(subplot)

            fwd = [subplot.plot(xvals, s, color=c, linestyle='-', linewidth=0.5) for s,c in zip(same,colors)]
            rev = [subplot.plot(xvals, o, color=c, linestyle='-', linewidth=0.5) for o,c in zip(opp,colors)]
            subplot.axhline(0, linestyle='--', linewidth=0.2)
            subplot.axvline(0, linestyle='--', linewidth=0.2)

            if motiflen:
                subplot.axvline(motiflen-1, linestyle='--', c='g', linewidth=0.2)

            xmin = xvals[0]
            xmax = xvals[-1]
            ymax = max([s.max() for s in same])
            ymin = min([o.min() for o in opp])
            subplot.axis([xmin, xmax, ymin, ymax])

            for text in subplot.get_xticklabels():
                text.set_fontsize(7)
                text.set_verticalalignment('center')

            ytick_locs = list(np.linspace(np.round(ymin,2),np.round(ymax,2),5))
            if 0 not in ytick_locs:
                ytick_locs.append(0)
                ytick_locs.sort()
            ytick_labels = tuple(['%.2f'%s for s in ytick_locs])
            subplot.set_yticks(ytick_locs)
            subplot.set_yticklabels(ytick_labels, color='k', fontsize=6, horizontalalignment='right')

            if subtitle and cellidx==0:
                bbox = subplot.get_position()
                xloc = bbox.xmin/3.
                yloc = (bbox.ymax+bbox.ymin)/2.
                plot.text(xloc, yloc, subtitle[index], fontsize=8, horizontalalignment='center', \
                    verticalalignment='center', transform=figure.transFigure)

            if cellnames and index==0:
                bbox = subplot.get_position()
                xloc = (bbox.xmax+bbox.xmin)/2.
                yloc = (3*bbox.ymax+1)/4.
                plot.text(xloc, yloc, cellnames[cellidx], fontsize=8, horizontalalignment='center', \
                    verticalalignment='bottom', transform=figure.transFigure)

            if hist:
                subplot = figure.add_subplot(numrows,numcols,(2*index+1)*numcols+cellidx+1)
                subplot = remove_spines(subplot)

                reads_unbound = np.power([read[0][idx].sum()+read[1][idx].sum() for idx in quantized[0] \
                    if read[0][idx].size==width and read[1][idx].size==width], 0.25)
                reads_bound = np.power([read[0][idx].sum()+read[1][idx].sum() for idx in quantized[-1] \
                    if read[0][idx].size==width and read[1][idx].size==width], 0.25)

                h0 = subplot.hist(reads_unbound, bins=200, color=colors[0], histtype='step', linewidth=0.2, normed=True)
                h1 = subplot.hist(reads_bound, bins=200, color=colors[-1], histtype='step', linewidth=0.2, normed=True)

                xmin = 0
                xmax = max([reads_bound.max(), reads_unbound.max()])
                ymin = 0
                ymax = max([h0[0].max(), h1[0].max()])
                subplot.axis([xmin, xmax, ymin, ymax])

                for text in subplot.get_xticklabels():
                    text.set_fontsize(7)
                    text.set_verticalalignment('center')

                ytick_locs = list(np.linspace(np.round(ymin,2),np.round(ymax,2),5))
                ytick_labels = tuple(['%.2f'%s for s in ytick_locs])
                subplot.set_yticks(ytick_locs)
                subplot.set_yticklabels(ytick_labels, color='k', fontsize=6, horizontalalignment='right')

                subplot.set_xlabel('Fourth root of total reads', fontsize=6, horizontalalignment='center')

    legends = ['(%.2f,%.2f)'%(xt[quant].min(),xt[quant].max()) for quant in quantized]
    leghandle = plot.figlegend(fwd, legends, loc='lower right', mode="expand", ncol=q)
    for text in leghandle.texts:
        text.set_fontsize(6)
    leghandle.set_frame_on(False)
    
    if title:
        plot.suptitle(title, fontsize=10)

    return figure
Example #30
def test():
    df_column = ['Name']
    df_column.extend([str(i) for i in range(1, seq_len + 1)])

    df = pd.DataFrame(columns=df_column)

    psnr_array = np.zeros((0, seq_len))
    ssim_array = np.zeros((0, seq_len))

    tqdm_loader = tqdm.tqdm(validationloader, ncols=80)

    imgsave_folder = os.path.join(args.checkpoint_dir, 'Saved_imgs')
    if not os.path.exists(imgsave_folder):
        os.mkdir(imgsave_folder)

    with torch.no_grad():
        for validationIndex, (validationData, validationFrameIndex,
                              validationFile) in enumerate(tqdm_loader):

            blurred_img = torch.zeros_like(validationData[0])
            for image in validationData:
                blurred_img += image
            blurred_img /= len(validationData)
            blurred_img = blurred_img.to(device)
            batch_size = blurred_img.shape[0]

            blurred_img = meanshift(blurred_img, mean, std, device, False)
            c = center_estimation(blurred_img)
            start, end = border_estimation(blurred_img, c)
            start = meanshift(start, mean, std, device, True)
            end = meanshift(end, mean, std, device, True)
            blurred_img = meanshift(blurred_img, mean, std, device, True)

            frame0 = validationData[0].to(device)
            frame1 = validationData[-1].to(device)

            batch_size = blurred_img.shape[0]
            parallel = torch.mean(compare_ftn(start, frame0) +
                                  compare_ftn(end, frame1),
                                  dim=(1, 2, 3))
            cross = torch.mean(compare_ftn(start, frame1) +
                               compare_ftn(end, frame0),
                               dim=(1, 2, 3))

            I0 = torch.zeros_like(blurred_img)
            I1 = torch.zeros_like(blurred_img)
            for b in range(batch_size):
                if parallel[b] <= cross[b]:
                    I0[b], I1[b] = start[b], end[b]
                else:
                    I0[b], I1[b] = end[b], start[b]

            psnrs = np.zeros((batch_size, seq_len))
            ssims = np.zeros((batch_size, seq_len))

            for vindex in range(seq_len):
                frameT = validationData[vindex]
                IFrame = frameT.to(device)

                if vindex == 0:
                    Ft_p = I0.clone()

                elif vindex == seq_len - 1:
                    Ft_p = I1.clone()

                else:
                    validationIndex = torch.ones(batch_size) * (vindex - 1)
                    validationIndex = validationIndex.long()
                    flowOut = flowComp(torch.cat((I0, I1), dim=1))
                    F_0_1 = flowOut[:, :2, :, :]
                    F_1_0 = flowOut[:, 2:, :, :]

                    fCoeff = superslomo.getFlowCoeff(validationIndex, device,
                                                     seq_len)

                    F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                    F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                    g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                    g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)

                    if args.add_blur:
                        intrpOut = ArbTimeFlowIntrp(
                            torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                       g_I1_F_t_1, g_I0_F_t_0, blurred_img),
                                      dim=1))
                    else:
                        intrpOut = ArbTimeFlowIntrp(
                            torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                       g_I1_F_t_1, g_I0_F_t_0),
                                      dim=1))

                    F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                    F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                    V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                    V_t_1 = 1 - V_t_0

                    g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                    g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                    wCoeff = superslomo.getWarpCoeff(validationIndex, device,
                                                     seq_len)

                    Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f +
                            wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (
                                wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                Ft_p = meanshift(Ft_p, mean, std, device, False)
                IFrame = meanshift(IFrame, mean, std, device, False)

                for b in range(batch_size):
                    foldername = os.path.basename(
                        os.path.dirname(validationFile[ctr_idx][b]))
                    filename = os.path.splitext(
                        os.path.basename(validationFile[vindex][b]))[0]

                    out_fname = foldername + '_' + filename + '_out.png'
                    gt_fname = foldername + '_' + filename + '.png'
                    out, gt = quantize(Ft_p[b]), quantize(IFrame[b])

                    # Uncomment the two lines below to save the images
                    # torchvision.utils.save_image(out, os.path.join(imgsave_folder, out_fname), normalize=True, range=(0,255))
                    # torchvision.utils.save_image(gt, os.path.join(imgsave_folder, gt_fname), normalize=True, range=(0,255))

                psnr, ssim = eval_metrics(Ft_p, IFrame)
                psnrs[:, vindex] = psnr.cpu().numpy()
                ssims[:, vindex] = ssim.cpu().numpy()

            for b in range(batch_size):
                rows = [validationFile[ctr_idx][b]]
                rows.extend(list(psnrs[b]))
                df = df.append(pd.Series(rows, index=df.columns),
                               ignore_index=True)

            df.to_csv('{}/results_PSNR.csv'.format(args.checkpoint_dir))
Example #31
                [psnr, ssim, output_image] = sess.run([psnr_, ssim_, out_],
                                                      feed_dict={
                                                          x_: image_bicubic,
                                                          y_: image_target,
                                                          h_: h,
                                                          w_: w
                                                      })

                psnr_score += psnr / num_val_images
                ssim_score += ssim / num_val_images

                if not os.path.exists(output_folder):
                    os.makedirs(output_folder, exist_ok=True)
                sio.imsave(
                    output_folder + validation_images[j],
                    utils.quantize(np.squeeze(output_image, axis=0) * 255.0))

            print("\r\r\r")
            print("Scores | PSNR: %.4g, MS-SSIM: %.4g" %
                  (psnr_score, ssim_score))
            print("\n-------------------------------------\n")
            sess.close()

    if compute_running_time:

        ##############################
        #  3 Computing running time  #
        ##############################

        print("Evaluating model speed")
        print("This can take a few minutes\n")
Example #32
    def build_arch(self):
        with tf.variable_scope('Conv1_layer'):
            # Conv1, [batch_size, 28, 28, 32]
            W1 = tf.get_variable('Weight1', shape=(5, 5, 1, 32), dtype=tf.float32,
                                initializer=tf.random_normal_initializer(stddev=cfg.stddev))
            biases1 = tf.get_variable('bias1', shape=(32))

            '''
            if not cfg.is_training:
                W1 = quantize(W1, cfg.bits)
                biases1 = quantize(biases1, cfg.bits)
            '''

            conv1 = tf.nn.relu(tf.nn.conv2d(self.X, W1, strides=[1, 1, 1, 1], padding='SAME') + biases1)
            assert conv1.get_shape() == [cfg.batch_size, 28, 28, 32]

        with tf.variable_scope('Pooling1_layer'):
            # Pooling1, [batch_size, 14, 14, 32]
            pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
            assert pool1.get_shape() == [cfg.batch_size, 14, 14, 32]

        with tf.variable_scope('Conv2_layer'):
            # Conv2, [batch_size, 14, 14, 64]
            W2 = tf.get_variable('Weight2', shape=(5, 5, 32, 64), dtype=tf.float32,
                                initializer=tf.random_normal_initializer(stddev=cfg.stddev))
            biases2 = tf.get_variable('bias2', shape=(64))

            '''
            if not cfg.is_training:
                W2 = quantize(W2, cfg.bits)
                biases2 = quantize(biases2, cfg.bits)
            '''

            conv2 = tf.nn.relu(tf.nn.conv2d(pool1, W2, strides=[1, 1, 1, 1], padding='SAME') + biases2)
            assert conv2.get_shape() == [cfg.batch_size, 14, 14, 64]

        with tf.variable_scope('Pooling2_layer'):
            # Pooling2, [batch_size, 7, 7, 64]
            pool2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
            assert pool2.get_shape() == [cfg.batch_size, 7, 7, 64]

        with tf.variable_scope('FC1_layer'):
            # FC1, [batch_size, 1024]
            W3 = tf.get_variable('Weight3', shape=(7 * 7 * 64, 1024), dtype=tf.float32, initializer=tf.random_normal_initializer(stddev=cfg.stddev))
            biases3 = tf.get_variable('bias3', shape=(1024))

            if not cfg.is_training:
                W3 = quantize(W3, cfg.bits)
                biases3 = quantize(biases3, cfg.bits)

            flatten = tf.reshape(pool2, [-1, 7 * 7 * 64])
            fc1 = tf.nn.relu(tf.matmul(flatten, W3) + biases3)
            assert fc1.get_shape() == [cfg.batch_size, 1024]

        with tf.variable_scope('Dropout'):
            keep_prob = 0.5 if cfg.is_training else 1.0
            dropout = tf.nn.dropout(fc1, keep_prob=keep_prob)

        with tf.variable_scope('FC2_layer'):
            # FC2, [batch_size, 10]
            W4 = tf.get_variable('Weight4', shape=(1024, 10), dtype=tf.float32, initializer=tf.random_normal_initializer(stddev=cfg.stddev))
            biases4 = tf.get_variable('bias4', shape=(10))

            if not cfg.is_training:
                W4 = quantize(W4, cfg.bits)
                biases4 = quantize(biases4, cfg.bits)

            # self.softmax_v = tf.nn.softmax(tf.matmul(dropout, W4) + biases4)
            self.softmax_v = tf.matmul(dropout, W4) + biases4
            assert self.softmax_v.get_shape() == [cfg.batch_size, 10]

        self.argmax_idx = tf.to_int32(tf.argmax(self.softmax_v, axis=1))
        assert self.argmax_idx.get_shape() == [cfg.batch_size]
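The quantize(W, cfg.bits) calls in build_arch() round trained weights and biases to a given bit width at inference time (when cfg.is_training is False). A sketch of such a graph-mode helper, assuming symmetric uniform quantization over each tensor's own range (the project's quantize is not shown):

import tensorflow as tf

def quantize_weights_sketch(w, bits):
    # Hypothetical: snap w onto a uniform grid of 2**(bits-1) - 1 positive
    # levels spanning [-max|w|, +max|w|], keeping the result as a float tensor.
    levels = float(2 ** (bits - 1) - 1)
    max_abs = tf.reduce_max(tf.abs(w)) + 1e-12
    step = max_abs / levels
    return tf.round(w / step) * step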