Example #1
 def backward(self, dout):
     FN, C, FH, FW = self.W.shape
     dout = dout.transpose(0,2,3,1).reshape(-1, FN)
     
     self.db = np.sum(dout, axis=0)
     self.dW = np.dot(self.col.T, dout)
     self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)
     
     dcol = np.dot(dout, self.col_W.T)
     dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)
     
     return dx
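Examples #1, #2, #4, #8 and #9 all call a textbook-style col2im with the signature col2im(col, input_shape, filter_h, filter_w, stride, pad), as in the "Deep Learning from Scratch" implementation. For reference, a minimal sketch of the helper those calls assume — an illustration, not necessarily the exact version each repository ships:

import numpy as np

def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    # Inverse of im2col: scatter-add each flattened patch back into an
    # image of shape input_shape = (N, C, H, W).
    N, C, H, W = input_shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1
    col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)

    img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1))
    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            # overlapping patches accumulate, which is exactly what the
            # gradient of im2col requires
            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]

    return img[:, :, pad:H + pad, pad:W + pad]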
Example #2
 def backward(self, dout):
     dout = dout.transpose(0, 2, 3, 1)
     
     pool_size = self.pool_h * self.pool_w
     dmax = np.zeros((dout.size, pool_size))
     dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
     dmax = dmax.reshape(dout.shape + (pool_size,))
     
     dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
     dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)
     
     return dx
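The core of this pooling backward is routing each upstream gradient to the argmax position of its window. A tiny self-contained check of that scatter trick (toy values, not taken from any repository above):

import numpy as np

pool_size = 4
dout_flat = np.array([1.0, 2.0, 3.0])   # upstream gradient, one value per window
arg_max = np.array([2, 0, 3])           # winning index recorded in the forward pass
dmax = np.zeros((dout_flat.size, pool_size))
dmax[np.arange(arg_max.size), arg_max] = dout_flat
print(dmax)
# [[0. 0. 1. 0.]
#  [2. 0. 0. 0.]
#  [0. 0. 0. 3.]]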
Example #3
    def backward(self, dout):
        FN, FC, FH, FW = self.W.shape
        W_col = self.W.reshape(FN, -1)

        dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)
        self.dW = self.im_col.T.dot(dout)
        self.dW = self.dW.transpose(1, 0).reshape(FN, FC, FH, FW)
        self.db = dout.sum(axis=0)

        dx = dout.dot(W_col)
        dx = col2im(dx, self.xshape, FH, FW, self.stride, self.padding)
        return dx
Example #4
    def backward(self, dout):
        FN, C, FH, FW = self.par['w'].shape
        dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)

        self.gra['b'] = np.sum(dout, axis=0)
        self.gra['w'] = np.dot(self.col.T, dout)
        self.gra['w'] = self.gra['w'].transpose(1, 0).reshape(FN, C, FH, FW)

        dcol = np.dot(dout, self.col_W.T)
        dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)

        return dx
Example #5
    def backward(self, dout):
        # dout arrives from the Affine layer already reshaped back to its
        # original 4-D form, so its shape here is (N, C, H, W)

        # transpose dout to channel-last so it matches the im2col layout
        dout = dout.transpose(0, 2, 3, 1)
        col = np.zeros((dout.size, self.pool_h*self.pool_w))
        col[np.arange(dout.size), self.argmax] = dout.flatten()
        col = col.reshape(-1, self.pool_h*self.pool_w*self.out_shape[1])
        im = col2im(col, self.img_shape, self.pool_h,
                    self.pool_w, self.stride, self.padding)

        return im
Example #6
def test_im2col():
    """
    x1 = np.random.rand(1, 3, 7, 7)
    col1 = im2col(x1, 5, 5, stride=1, pad=0)
    print(f'the shape of result is {col1.shape}')

    x1 = np.random.rand(10, 3, 7, 7)
    col1 = im2col(x1, 5, 5, stride=1, pad=0)
    print(f'the shape of result is {col1.shape}')
    """
    col = np.arange(90*75).reshape(90, 75)
    r1 = col2im(col, (10, 3, 7, 7), 5, 5, stride=1, pad=0)
    r2 = book_util.col2im(col, (10, 3, 7, 7), 5, 5, stride=1, pad=0)
    print(f"test result of test_im2col: {(r1==r2).all()}")
Example #7
    def backward(self, inputs, output_errors):

        # Biases gradients
        self.g_biases += output_errors.sum(axis=(1, 2)).reshape(
            output_errors.shape[0], 1)

        # Weights gradients
        errors_reshaped = output_errors.reshape(self.outputs_depth, -1)
        self.g_weights = np.dot(errors_reshaped, self.X_col)
        self.g_weights = self.g_weights.reshape(self.weights.shape)

        # Inputs gradients
        W_reshaped = self.weights.reshape(self.outputs_depth, -1)
        g_input_col = np.dot(W_reshaped.T, errors_reshaped)
        g_input = col2im(g_input_col, inputs.shape, self.X_col_indices.T)

        return g_input
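Note that this variant uses a different col2im signature from the textbook one in most of the other examples: it receives the transposed patch indices (self.X_col_indices.T), presumably computed and cached during the forward pass, rather than re-deriving the scatter geometry from filter size, stride and padding.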
Example #8
    def backward(self, dy):
        N, C, out_h, out_w = dy.shape

        # dy.shape: (N, C, out_h, out_w)
        # after transpose: (N, out_h, out_w, C)
        dy = dy.transpose(0, 2, 3, 1)

        # col.shape: (N*out_h*out_w*C, pool_h*pool_w)
        # after reshape: (N*out_h*out_w, C*pool_h*pool_w)
        col = np.zeros((N*out_h*out_w*C, self.pool_h*self.pool_w))
        col[np.arange(N*out_h*out_w*C), self.argmax.flatten()] = dy.flatten()
        col = col.reshape(N*out_h*out_w, -1)

        # col.shape: (N*out_h*out_w, C*pool_h*pool_w)
        # dx.shape (N, C, H, W)
        dx = col2im(col, self.x_shape, self.pool_h, self.pool_w, self.stride, self.pad)
        return dx
Example #9
 def backward(self, dy):
     # dy.shape: (N, FN, OH, OW)
     # after transpose: (N, OH, OW, FN)
     # after reshape: (N*OH*OW, FN)
     N, FN, OH, OW = dy.shape
     dy = dy.transpose(0, 2, 3, 1).reshape(-1, FN)
     
     # y = x.w + b => dy/dx = w, dy/dw = x, dy/db = 1
     # dy.shape:                     (N*OH*OW, FN)
     # dcol.shape(same as col.shape):(N*OH*OW, C*FH*FW)
     # dw.shape(same as W.shape):    (C*FH*FW, FN)
     # db.shape(same as b.shape):    (FN,)
     # dx.shape(same as x.shape):    (N, C, H, W)
     self.dW = np.dot(self.col.T, dy)
     self.db = np.sum(dy, axis=0)
     dcol = np.dot(dy, self.W.T)
     dx = col2im(dcol, self.x_shape, self.FH, self.FW, self.stride, self.pad)
     return dx
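The shape comments above encode the standard matrix-calculus identities for y = col @ W + b: dcol = dy @ W.T, dW = col.T @ dy, db = dy.sum(axis=0). A toy shape check with hypothetical dimensions (names invented for illustration):

import numpy as np

rng = np.random.default_rng(0)
col = rng.standard_normal((6, 4))   # stand-in for (N*OH*OW, C*FH*FW)
W = rng.standard_normal((4, 3))     # stand-in for (C*FH*FW, FN)
dy = rng.standard_normal((6, 3))    # stand-in for (N*OH*OW, FN)

dW = col.T @ dy
dcol = dy @ W.T
db = dy.sum(axis=0)
assert dW.shape == W.shape and dcol.shape == col.shape and db.shape == (3,)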
Example #10
	def backward(self, output_errors):

		assert output_errors.shape == self.output_shape

		grad_b = np.sum(output_errors, axis=(0, 2, 3))
		grad_b = np.reshape(grad_b, (self.output_depth, 1))

		output_errors = np.transpose(output_errors, (1, 2, 3, 0))
		output_errors = np.reshape(output_errors, (self.output_depth, -1))
		
		W_reshaped = np.reshape(self.W, (self.output_depth, -1))
		grad_x_blocks = np.dot(output_errors.T, W_reshaped).T
		grad_x = col2im(grad_x_blocks, self.input_shape, self.filter_size,
		                self.filter_size, padding=0, stride=self.stride)
		grad_x = grad_x[:, :, self.pad1:-self.pad2, self.pad1:-self.pad2]

		grad_W = np.dot(output_errors, self.X_blocks.T)
		grad_W = np.reshape(grad_W, self.W.shape)

		self.grad_W = grad_W
		self.grad_b = grad_b

		return grad_x
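A design note on this variant: col2im is called with padding=0, and the (possibly asymmetric) padding applied in the forward pass is cropped off afterwards via pad1/pad2, rather than being undone inside col2im itself. That is why grad_x is sliced before being returned.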
Example #11
def run(model_name,
        base_dir,
        data_dir,
        task='sc',
        train_batch_size=128,
        eval_batch_size=1024,
        epochs=125,
        mode='train'):

    _NUM_TRAIN_IMAGES = FLAGS.num_train_images
    _NUM_EVAL_IMAGES = FLAGS.num_val_images

    if task == 'sc' or task == 'lasso':
        training_steps_per_epoch = int(_NUM_TRAIN_IMAGES // train_batch_size)
        validation_steps_per_epoch = int(_NUM_EVAL_IMAGES // eval_batch_size)
    elif task == 'cs':
        training_steps_per_epoch = 3125
        validation_steps_per_epoch = 10

    _BASE_LR = FLAGS.base_lr

    # Deal with paths of data, checkpoints and logs
    base_dir = os.path.abspath(base_dir)
    model_dir = os.path.join(base_dir, 'models', FLAGS.exp_name,
                             'replicate_' + str(FLAGS.replicate))
    log_dir = os.path.join(base_dir, 'logs', FLAGS.exp_name,
                           'replicate_' + str(FLAGS.replicate))
    logging.info('Saving checkpoints at %s', model_dir)
    logging.info('Saving tensorboard summaries at %s', log_dir)
    logging.info('Use training batch size: %s.', train_batch_size)
    logging.info('Use eval batch size: %s.', eval_batch_size)
    logging.info('Training model using data_dir in directory: %s', data_dir)

    if task == 'sc' or task == 'lasso':
        A = np.load(os.path.join(data_dir, 'A.npy'),
                    allow_pickle=True).astype(np.float32)
        M, N = A.shape
        F = None
        D = None
    elif task == 'cs':
        A = np.load(os.path.join(data_dir, 'A_128_512.npy'),
                    allow_pickle=True).astype(np.float32)
        D = np.load(os.path.join(data_dir, 'D_256_512.npy'),
                    allow_pickle=True).astype(np.float32)
        N = D.shape[0]
        F = D.shape[1]
    else:
        raise ValueError('invalid task type')

    if FLAGS.model_name.startswith('alista'):
        alista_W = np.load(os.path.join(data_dir, 'W.npy'),
                           allow_pickle=True).astype(np.float32)

    np.random.seed(FLAGS.seed)

    if mode == 'train':
        train_dataset = data_preprocessing.input_fn(True,
                                                    data_dir,
                                                    train_batch_size,
                                                    task,
                                                    drop_remainder=False,
                                                    A=A)
        val_dataset = data_preprocessing.input_fn(False,
                                                  data_dir,
                                                  eval_batch_size,
                                                  task,
                                                  drop_remainder=False,
                                                  A=A)

    summary_writer = tf.summary.create_file_writer(log_dir)

    # Define a Lista model
    if FLAGS.model_name == 'lista':
        model = models.Lista(A,
                             FLAGS.num_layers,
                             FLAGS.model_lam,
                             FLAGS.share_W,
                             D,
                             name='Lista')
        output_interval = N
    elif FLAGS.model_name == 'lfista':
        model = models.Lfista(A,
                              FLAGS.num_layers,
                              FLAGS.model_lam,
                              FLAGS.share_W,
                              D,
                              name='Lfista')
        output_interval = N
    elif FLAGS.model_name == 'lamp':
        model = models.Lamp(A,
                            FLAGS.num_layers,
                            FLAGS.model_lam,
                            FLAGS.share_W,
                            D,
                            name='Lamp')
        output_interval = M + N
    elif FLAGS.model_name == 'step_lista':
        assert FLAGS.model_lam == FLAGS.lasso_lam
        model = models.StepLista(A,
                                 FLAGS.num_layers,
                                 FLAGS.lasso_lam,
                                 D,
                                 name='StepLista')
        output_interval = N
    elif FLAGS.model_name == 'lista_cp':
        model = models.ListaCp(A,
                               FLAGS.num_layers,
                               FLAGS.model_lam,
                               FLAGS.share_W,
                               D,
                               name='ListaCp')
        output_interval = N
    elif FLAGS.model_name == 'lista_cpss':
        model = models.ListaCpss(A,
                                 FLAGS.num_layers,
                                 FLAGS.model_lam,
                                 FLAGS.ss_q_per_layer,
                                 FLAGS.ss_maxq,
                                 FLAGS.share_W,
                                 D,
                                 name='ListaCpss')
        output_interval = N
    elif FLAGS.model_name == 'alista':
        model = models.Alista(A,
                              alista_W,
                              FLAGS.num_layers,
                              FLAGS.model_lam,
                              FLAGS.ss_q_per_layer,
                              FLAGS.ss_maxq,
                              D,
                              name='Alista')
        output_interval = N
    elif FLAGS.model_name == 'glista':
        model = models.Glista(A,
                              FLAGS.num_layers,
                              FLAGS.model_lam,
                              FLAGS.ss_q_per_layer,
                              FLAGS.ss_maxq,
                              FLAGS.share_W,
                              D,
                              name='Glista',
                              alti=FLAGS.glista_alti,
                              gain_func=FLAGS.gain_func)
        output_interval = N * 2
    elif FLAGS.model_name == 'tista':
        assert FLAGS.tista_sigma2 is not None
        model = models.Tista(A,
                             FLAGS.num_layers,
                             FLAGS.model_lam,
                             FLAGS.tista_sigma2,
                             FLAGS.share_W,
                             D,
                             name='Tista')
        output_interval = N
    else:
        raise NotImplementedError(
            'Other model types are not implemented yet')

    var_list = {}
    checkpoint = tf.train.Checkpoint(model=model)
    prev_model, prev_layer = utils.check_and_load_partial(
        model_dir, FLAGS.num_layers)

    if task == 'lasso':
        _A_const = tf.constant(A, name='A_lasso_const')
        loss = utils.LassoLoss(_A_const, FLAGS.lasso_lam, N, F)
        metrics_compile = [
            utils.LassoObjective('lasso', _A_const, FLAGS.lasso_lam, M, N, -1)
        ]
        monitor = 'val_lasso'
    else:
        loss = utils.MSE(N, F)
        metrics_compile = [utils.NMSE('nmse', N)]
        monitor = 'val_nmse'

    if mode == 'test':
        if prev_layer != FLAGS.num_layers:
            raise ValueError('Should have a fully trained model!')
        checkpoint.restore(prev_model).assert_existing_objects_matched()
        res_dict = {}
        eval_files = FLAGS.test_files
        if task == 'lasso':
            # do layer-wise testing
            test_metrics = [
                utils.LassoObjective('lasso_layer{}'.format(i), _A_const,
                                     FLAGS.lasso_lam, M, N, i)
                for i in range(FLAGS.num_layers)
            ]
        else:
            test_metrics = [
                utils.EvalNMSE('nmse_layer{}'.format(i), M, N, output_interval,
                               i) for i in range(FLAGS.num_layers)
            ]
        for layer_id in range(FLAGS.num_layers):
            model.create_cell(layer_id)
        for i in range(len(eval_files)):
            val_ds = data_preprocessing.input_fn(False,
                                                 data_dir,
                                                 eval_batch_size,
                                                 drop_remainder=False,
                                                 A=A,
                                                 filename=eval_files[i])
            logging.info('Compiling model.')
            model.compile(optimizer=tf.keras.optimizers.Adam(_BASE_LR),
                          loss=loss,
                          metrics=test_metrics)
            metrics = model.evaluate(x=val_ds, verbose=2)
            if task == 'lasso':
                output = model.predict(x=val_ds, verbose=2)
                final_xh = output[:, -N:]
                eval_file_basename = os.path.basename(
                    eval_files[i]).strip('.npy')
                np.save(
                    os.path.join(model_dir,
                                 eval_file_basename + '_final_output.npy'),
                    final_xh)
            res_dict[eval_files[i]] = metrics[1:]
        for k, v in res_dict.items():
            logging.info('%s : %s', k, str(v))
        return

    for layer_id in range(FLAGS.num_layers):
        logging.info('Building Lista Keras model.')
        model.create_cell(layer_id)

        # Deal with the variables that have been trained in previous layers
        for name in var_list:
            var_list[name] += 1
        # Deal with the variables in the current layer
        for v in model.layers[layer_id].trainable_variables:
            if v.name not in var_list:
                var_list[v.name] = 0

        if layer_id == prev_layer - 1 and prev_model:
            checkpoint.restore(prev_model).assert_existing_objects_matched()
            logging.info('Checkpoint restored from %s.', prev_model)
            model.compile(optimizer=tf.keras.optimizers.Adam(_BASE_LR),
                          loss=loss,
                          metrics=metrics_compile)
            metrics = model.evaluate(x=val_dataset, verbose=2)
            val_metric = metrics[1]
            with summary_writer.as_default():
                for value, key in zip(metrics, model.metrics_names):
                    tf.summary.scalar(key, value, layer_id + 1)
            continue
        elif layer_id < prev_layer:
            logging.info('Skip layer %d.', layer_id + 1)
            continue

        logging.info('Compiling model.')

        model.compile(optimizer=utils.Adam(var_list,
                                           True,
                                           learning_rate=_BASE_LR),
                      loss=loss,
                      metrics=metrics_compile)

        earlystopping_cb = tf.keras.callbacks.EarlyStopping(
            monitor=monitor,
            min_delta=0,
            patience=5,
            mode='min',
            restore_best_weights=False)
        cbs = [earlystopping_cb]

        logging.info('Fitting Lista Keras model.')
        model.fit(train_dataset,
                  epochs=epochs,
                  steps_per_epoch=training_steps_per_epoch,
                  callbacks=cbs,
                  validation_data=val_dataset,
                  validation_steps=validation_steps_per_epoch,
                  verbose=2)
        logging.info('Finished fitting Lista Keras model.')
        model.summary()

        for i in range(2):
            logging.info('Compiling model.')
            model.compile(optimizer=utils.Adam(var_list,
                                               learning_rate=_BASE_LR * 0.2 *
                                               0.1**i),
                          loss=loss,
                          metrics=metrics_compile)

            earlystopping_cb = tf.keras.callbacks.EarlyStopping(
                monitor=monitor,
                min_delta=0,
                patience=5,
                mode='min',
                restore_best_weights=False)
            cbs = [earlystopping_cb]

            logging.info('Fitting Lista Keras model.')
            history = model.fit(train_dataset,
                                epochs=epochs,
                                steps_per_epoch=training_steps_per_epoch,
                                callbacks=cbs,
                                validation_data=val_dataset,
                                validation_steps=validation_steps_per_epoch,
                                verbose=2)
            logging.info('Finished fitting Lista Keras model.')
        model.summary()
        val_metric = history.history[monitor][-1]
        with summary_writer.as_default():
            for key in history.history.keys():
                tf.summary.scalar(key, history.history[key][-1], layer_id + 1)
        try:
            checkpoint.save(utils.save_partial(model_dir, layer_id))
            logging.info('Checkpoint saved at %s',
                         utils.save_partial(model_dir, layer_id))
        except tf.errors.NotFoundError:
            pass

    if task == 'cs':
        raise NotImplementedError(
            'Compressive sensing testing part not implemented yet')
        data = np.load(os.path.join(data_dir, 'set11.npy'), allow_pickle=True)
        phi = np.load(os.path.join(data_dir, 'phi_128_256.npy'),
                      allow_pickle=True).astype(np.float32)
        psnr = utils.PSNR()
        for im in data:
            im_ = im.astype(np.float32)
            cols = utils.im2cols(im_)
            patch_mean = np.mean(cols, axis=1, keepdims=True)
            fs = ((cols - patch_mean) / 255.0).astype(np.float32)
            ys = np.matmul(fs, phi.transpose())
            fs_rec = model.predict_on_batch(ys)[:, -N:]
            cols_rec = fs_rec * 255.0 + patch_mean
            im_rec = utils.col2im(cols_rec).astype(np.float32)
            psnr.update_state(im_, im_rec)
        logging.info('Test PSNR: %f', psnr.result().numpy())
        val_metric = float(psnr.result().numpy())
        with summary_writer.as_default():
            tf.summary.scalar('test_psnr', psnr.result().numpy(), 0)

    return val_metric
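A hypothetical entry point for the function above; the absl-style FLAGS it reads (num_layers, model_name, base_lr, and so on) are assumed to be defined elsewhere in the module:

# Hypothetical usage sketch; flag definitions and app wiring are assumed to
# live elsewhere in this module.
if __name__ == '__main__':
    val_metric = run(model_name='lista',
                     base_dir='./experiments',
                     data_dir='./data',
                     task='sc',
                     mode='train')
    print('final validation metric:', val_metric)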