Exemple #1
0
def setup_solvers(train_params):
    """Create Adam solvers for the generator, discriminator and keypoint
    detector networks.

    Each solver picks up the parameters registered under the matching
    nnabla parameter scope, so the networks must already have been built
    (their parameters created) before this is called.

    Args:
        train_params (dict): must provide learning rates under the keys
            'lr_generator', 'lr_discriminator' and 'lr_kp_detector'.

    Returns:
        dict: mapping scope name -> configured ``S.Adam`` solver.
    """
    solver_dict = dict()

    # All three networks share the same Adam betas; only the learning
    # rate differs. The original repeated this block three times (and
    # passed the lr positionally for two of them); one loop keeps the
    # three setups consistent.
    for scope, lr_key in (("generator", "lr_generator"),
                          ("discriminator", "lr_discriminator"),
                          ("kp_detector", "lr_kp_detector")):
        solver = S.Adam(alpha=train_params[lr_key], beta1=0.5, beta2=0.999)
        with nn.parameter_scope(scope):
            solver.set_parameters(nn.get_parameters())
        solver_dict[scope] = solver

    return solver_dict
Exemple #2
0
    def _build(self):
        """Construct the inference, training and target-update graphs.

        Builds, in order: (1) a single-sample actor graph for inference,
        (2) critic and actor losses over a training batch, (3) an Adam
        solver per network, and (4) soft-update / hard-sync ops copying
        'trainable'-scope parameters into the 'target' scope (DDPG-style
        actor-critic setup).
        """
        # inference: batch of 1 observation -> action
        self.infer_obs_t = nn.Variable((1,) + self.obs_shape)

        with nn.parameter_scope('trainable'):
            self.infer_policy_t = policy_network(self.infer_obs_t,
                                                 self.action_size, 'actor')

        # training batch placeholders: (s_t, a_t, r_{t+1}, s_{t+1}, done)
        self.obss_t = nn.Variable((self.batch_size,) + self.obs_shape)
        self.acts_t = nn.Variable((self.batch_size, self.action_size))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size,) + self.obs_shape)
        self.ters_tp1 = nn.Variable((self.batch_size, 1))

        # critic training: TD target computed with the *target* networks
        with nn.parameter_scope('trainable'):
            q_t = q_network(self.obss_t, self.acts_t, 'critic')
        with nn.parameter_scope('target'):
            policy_tp1 = policy_network(self.obss_tp1, self.action_size,
                                        'actor')
            q_tp1 = q_network(self.obss_tp1, policy_tp1, 'critic')
        # Bellman target; the (1 - terminal) factor zeroes the bootstrap
        # term on episode ends.
        y = self.rews_tp1 + self.gamma * q_tp1 * (1.0 - self.ters_tp1)
        self.critic_loss = F.mean(F.squared_error(q_t, y))

        # actor training: maximize Q(s, pi(s)) -> minimize its negation
        with nn.parameter_scope('trainable'):
            policy_t = policy_network(self.obss_t, self.action_size, 'actor')
            q_t_with_actor = q_network(self.obss_t, policy_t, 'critic')
        self.actor_loss = -F.mean(q_t_with_actor)

        # get neural network parameters, split per sub-scope so each
        # solver only updates its own network
        with nn.parameter_scope('trainable'):
            with nn.parameter_scope('critic'):
                critic_params = nn.get_parameters()
            with nn.parameter_scope('actor'):
                actor_params = nn.get_parameters()

        # setup optimizers
        self.critic_solver = S.Adam(self.critic_lr)
        self.critic_solver.set_parameters(critic_params)
        self.actor_solver = S.Adam(self.actor_lr)
        self.actor_solver.set_parameters(actor_params)

        with nn.parameter_scope('trainable'):
            trainable_params = nn.get_parameters()
        with nn.parameter_scope('target'):
            target_params = nn.get_parameters()

        # build target update ops (executed elsewhere via the sink vars)
        update_targets = []
        sync_targets = []
        for key, src in trainable_params.items():
            dst = target_params[key]
            # Polyak averaging: target <- (1 - tau)*target + tau*trainable
            updated_dst = (1.0 - self.tau) * dst + self.tau * src
            update_targets.append(F.assign(dst, updated_dst))
            # Hard copy: target <- trainable
            sync_targets.append(F.assign(dst, src))
        self.update_target_expr = F.sink(*update_targets)
        self.sync_target_expr = F.sink(*sync_targets)
Exemple #3
0
def main():
    """Entry point: optionally train the facade pix2pix model, then
    optionally generate images with the trained (or supplied) model."""
    args = get_args()

    # Pick the compute context before any graph is built.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    nn.set_default_context(
        get_extension_context(args.context, device_id=args.device_id))

    model_path = args.model

    if args.train:
        logger.info("Initialing DataSource.")
        # Training batches are shuffled; validation keeps a fixed order
        # and skips random cropping.
        train_source = facade.facade_data_iterator(
            args.traindir,
            args.batchsize,
            shuffle=True,
            with_memory_cache=False)
        val_source = facade.facade_data_iterator(
            args.valdir,
            args.batchsize,
            random_crop=False,
            shuffle=False,
            with_memory_cache=False)

        # One Adam per network, same hyper-parameters for both.
        gen_solver = S.Adam(alpha=args.lrate, beta1=args.beta1)
        dis_solver = S.Adam(alpha=args.lrate, beta1=args.beta1)

        model_path = train(unet.generator, unet.discriminator,
                           args.patch_gan, gen_solver, dis_solver,
                           args.weight_l1, train_source, val_source,
                           args.epoch, nm.Monitor(args.logdir),
                           args.monitor_interval)

    if args.generate:
        if model_path is None:
            # Neither --model nor a completed training run supplied one.
            logger.error("Trained model was NOT given.")
        else:
            logger.info("Generating from DataSource.")
            test_source = facade.facade_data_iterator(
                args.testdir,
                args.batchsize,
                shuffle=False,
                with_memory_cache=False)
            generate(unet.generator, model_path, test_source, args.logdir)
def main():
    """Parse CLI options, build the decoder training setup and train."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default=None, type=str)
    parser.add_argument('--info', default=None, type=str)
    args = parser.parse_args()

    config = load_decoder_config(args.config)
    # An optional suffix lets several runs of one config coexist.
    if args.info:
        config["experiment_name"] += args.info

    pprint.pprint(config)

    #########################
    # Compute context (cpu / cuda / ...).
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {config["context"]}.')
    nn.set_default_context(
        get_extension_context(config["context"],
                              device_id=config["device_id"]))
    #########################

    # Data Loading
    logger.info('Initialing Datasource')
    train_cfg = config["train"]
    iterator = data.celebv_data_iterator(
        dataset_mode="decoder",
        celeb_name=config["trg_celeb_name"],
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode=config["mode"],
        batch_size=train_cfg["batch_size"],
        shuffle=train_cfg["shuffle"],
        with_memory_cache=train_cfg["with_memory_cache"],
        with_file_cache=train_cfg["with_file_cache"],
    )

    log_path = os.path.join(config["logdir"], "decoder",
                            config["trg_celeb_name"],
                            config["experiment_name"])
    monitor = nm.Monitor(log_path)

    # One Adam per network, identical hyper-parameters.
    solver_netG = S.Adam(alpha=train_cfg["lr"], beta1=train_cfg["beta1"])
    solver_netD = S.Adam(alpha=train_cfg["lr"], beta1=train_cfg["beta1"])

    train(config, models.netG_decoder, models.netD_decoder,
          solver_netG, solver_netD, iterator, monitor)
Exemple #5
0
def create_network(batch_size, num_dilations, learning_rate):
    """Build the WaveNet training graph and its Adam solver.

    Args:
        batch_size (int): samples per batch.
        num_dilations: dilation configuration forwarded to WaveNet.
        learning_rate (float): Adam learning rate.

    Returns:
        tuple: (input variable x, target variable t, scalar loss, solver)
    """
    # Input waveform ids: (B, T, 1); one-hot to (B, T, C), then move
    # channels first for the convolutional stack: (B, C, T).
    x = nn.Variable(shape=(batch_size, data_config.duration, 1))
    wavenet_input = F.transpose(
        F.one_hot(x, shape=(data_config.q_bit_len, )), (0, 2, 1))

    # No speaker conditioning in this setup (embedding is None).
    net = WaveNet(num_dilations)
    pred = F.transpose(net(wavenet_input, None), (0, 2, 1))

    # Target samples: (B, T, 1)
    t = nn.Variable(shape=(batch_size, data_config.duration, 1))
    loss = F.mean(F.softmax_cross_entropy(pred, t))

    # Create Solver over every registered parameter.
    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    return x, t, loss, solver
Exemple #6
0
def main():
    """Train the CNN, printing train metrics every 10 steps and test-set
    metrics every 100 steps."""
    # Shape of one flattened input batch.
    x = nn.variable.Variable(
        [BATCH_SIZE, IMAGE_DEPTH * IMAGE_WIDTH * IMAGE_HEIGHT])
    # Shape of one label batch.
    t = nn.variable.Variable([BATCH_SIZE, LABEL_NUM])

    pred = convolution(x)
    loss_ = loss(pred, t)

    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    data = InputData()

    for i in range(NUM_STEP):
        # Run the whole test set every 100 steps.
        if i % 100 == 0:
            total_loss = 0.0
            total_acc = 0.0
            n_batches = 0
            for t.d, x.d in data.test_data():
                loss_.forward()
                total_loss += loss_.d
                total_acc += accuracy(pred, t)
                n_batches += 1
            # BUG FIX: the original divided by the last enumerate index
            # (batch count - 1), which skewed the averages and raised
            # ZeroDivisionError when the test set had a single batch
            # (and NameError when it was empty). Divide by the actual
            # batch count, guarding the empty case.
            if n_batches:
                print("Step: %05d Test loss: %0.05f Test accuracy: %0.05f" %
                      (i, total_loss / n_batches, total_acc / n_batches))
        t.d, x.d = data.next_batch()
        loss_.forward()
        solver.zero_grad()
        loss_.backward()
        solver.weight_decay(DECAY_RATE)
        solver.update()
        if i % 10 == 0:
            print("Step: %05d Train loss: %0.05f Train accuracy: %0.05f" %
                  (i, loss_.d, accuracy(pred, t)))
Exemple #7
0
    def build_train_graph(self, batch):
        """Build the Q-learning training graph and solver from one batch.

        Args:
            batch: tuple (obs, action, reward, terminal, newobs) of
                arrays whose shapes size the input variables.

        Side effects: sets ``self.solver``, ``self.v`` (graph variables
        namedtuple) and ``self.built``; syncs the target network.
        """
        self.solver = S.Adam(self.learning_rate)

        obs, action, reward, terminal, newobs = batch
        # Create input variables
        s = nn.Variable(obs.shape)
        a = nn.Variable(action.shape)
        r = nn.Variable(reward.shape)
        t = nn.Variable(terminal.shape)
        snext = nn.Variable(newobs.shape)
        # Only the online Q network's parameters are given to the solver;
        # the target network (name_qnext) is updated via sync_models().
        with nn.parameter_scope(self.name_q):
            q = self.q_builder(s, self.num_actions, test=False)
            self.solver.set_parameters(nn.get_parameters())
        with nn.parameter_scope(self.name_qnext):
            qnext = self.q_builder(snext, self.num_actions, test=True)
        # No gradients flow through the target network.
        qnext.need_grad = False
        # Clip rewards into [-clip_reward, clip_reward].
        clipped_r = F.minimum_scalar(F.maximum_scalar(
            r, -self.clip_reward), self.clip_reward)
        # Q-value of the action actually taken, picked via a one-hot mask.
        q_a = F.sum(
            q * F.one_hot(F.reshape(a, (-1, 1), inplace=False), (q.shape[1],)), axis=1)
        # Bellman target; (1 - t) zeroes the bootstrap term on terminals.
        target = clipped_r + self.gamma * (1 - t) * F.max(qnext, axis=1)
        loss = F.mean(F.huber_loss(q_a, target))
        Variables = namedtuple(
            'Variables', ['s', 'a', 'r', 't', 'snext', 'q', 'loss'])
        self.v = Variables(s, a, r, t, snext, q, loss)
        self.sync_models()
        self.built = True
Exemple #8
0
    def _setup_solvers(self):
        """Create Adam solvers and step LR schedulers for the
        discriminator and generator parameters under this model's scope."""
        def make(lr, params):
            # Shared recipe: milestone LR schedule + Adam with the
            # configured beta1 and fixed beta2.
            scheduler = StepScheduler(lr, self.gamma, [self.lr_milestone])
            solver = S.Adam(lr, beta1=self.beta1, beta2=0.999)
            solver.set_parameters(params)
            return scheduler, solver

        # Collect each network's parameters from its sub-scope.
        with nn.parameter_scope('%s/discriminator' % self.scope):
            d_params = nn.get_parameters()
        with nn.parameter_scope('%s/generator' % self.scope):
            g_params = nn.get_parameters()

        self.d_lr_scheduler, self.d_solver = make(self.d_lr, d_params)
        self.g_lr_scheduler, self.g_solver = make(self.g_lr, g_params)
Exemple #9
0
def train_model(model, data, labels):
    """Fit `model` on (data, labels) with 1500 Adam steps of categorical
    cross-entropy."""
    # First call builds the graph so get_parameters() sees the weights.
    model(data)
    solver = S.Adam(alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-08)
    solver.set_parameters(nn.get_parameters())
    for _ in range(1500):
        step_loss = F.categorical_cross_entropy(model(data), labels)
        step_loss.forward()
        solver.zero_grad()
        step_loss.backward()
        solver.update()
Exemple #10
0
    def train(self):
        """Train the LSTM forecaster, validating every `_val_interval`
        iterations, then save the final parameters.

        Builds separate graphs for training and validation with the same
        network builder (parameters are reused through nnabla's global
        parameter registry), loads the configured datasets, and logs
        loss/error/time series via monitors.
        """
        # variables for training
        tx_in = nn.Variable(
            [self._batch_size, self._x_input_length, self._cols_size])
        tx_out = nn.Variable(
            [self._batch_size, self._x_output_length, self._cols_size])
        tpred = self.network(tx_in, self._lstm_unit_name, self._lstm_units)
        # Keep the prediction buffer alive so _train can read it.
        tpred.persistent = True
        loss = F.mean(F.squared_error(tpred, tx_out))
        solver = S.Adam(self._learning_rate)
        solver.set_parameters(nn.get_parameters())

        # variables for validation (same builder, fresh input variables)
        vx_in = nn.Variable(
            [self._batch_size, self._x_input_length, self._cols_size])
        vx_out = nn.Variable(
            [self._batch_size, self._x_output_length, self._cols_size])
        vpred = self.network(vx_in, self._lstm_unit_name, self._lstm_units)

        # data iterators
        tdata = self._load_dataset(self._training_dataset_path,
                                   self._batch_size,
                                   shuffle=True)
        vdata = self._load_dataset(self._validation_dataset_path,
                                   self._batch_size,
                                   shuffle=True)

        # monitors
        from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
        monitor = Monitor(self._monitor_path)
        monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
        monitor_err = MonitorSeries("Training error", monitor, interval=10)
        monitor_time = MonitorTimeElapsed("Training time",
                                          monitor,
                                          interval=100)
        monitor_verr = MonitorSeries("Validation error", monitor, interval=10)

        # Training loop
        for i in range(self._max_iter):
            # Periodic validation; _validate returns the summed error over
            # _val_iter batches, hence the division when logging.
            if i % self._val_interval == 0:
                ve = self._validate(vpred, vx_in, vx_out, vdata,
                                    self._val_iter)
                monitor_verr.add(i, ve / self._val_iter)
            te = self._train(tpred, solver, loss, tx_in, tx_out, tdata.next(),
                             self._weight_decay)
            monitor_loss.add(i, loss.d.copy())
            monitor_err.add(i, te)
            monitor_time.add(i)
        # One more validation after the final update.
        ve = self._validate(vpred, vx_in, vx_out, vdata, self._val_iter)
        monitor_verr.add(i, ve / self._val_iter)

        # Save the final model parameters (note: despite the original
        # comment, no best-model tracking is done — this is the last state).
        nn.save_parameters(self._model_params_path)
Exemple #11
0
    def __init__(self,
                 batch_size=32,
                 learning_rate=1e-4,
                 max_iter=5086,
                 total_epochs=20,
                 monitor_path=None,
                 val_weight=None,
                 model_load_path=None):
        """
        Construct all the necessary attributes for the attribute classifier.

        Builds a ResNet50-based training graph with a binary (sigmoid)
        classifier head, a separate validation graph sharing the same
        backbone, and an Adam solver over all registered parameters.

        Args:
            batch_size (int): number of samples contained in each generated batch
            learning_rate (float) : learning rate
            max_iter (int) : maximum iterations for an epoch
            total_epochs (int) : total epochs to train the model
            val_weight : sample weights
            monitor_path (str) : model parameter to be saved
            model_load_path (str) : load the model
        """
        self.batch_size = batch_size
        # Resnet 50
        # training graph
        model = ResNet50()
        self.input_image = nn.Variable((self.batch_size, ) + model.input_shape)
        self.label = nn.Variable([self.batch_size, 1])
        # fine tuning: backbone up to the pooling layer, in training mode
        pool = model(self.input_image, training=True, use_up_to='pool')
        self.clf = clf_resnet50(pool)
        # Keep classifier logits readable after forward passes.
        self.clf.persistent = True
        # loss: binary cross-entropy on the classifier logits
        self.loss = F.mean(F.sigmoid_cross_entropy(self.clf, self.label))
        # hyper parameters
        self.solver = S.Adam(learning_rate)
        self.solver.set_parameters(nn.get_parameters())

        # validation graph (inference mode); outputs probabilities
        self.x_v = nn.Variable((self.batch_size, ) + model.input_shape)
        pool_v = model(self.x_v, training=False, use_up_to='pool')
        self.v_clf = clf_resnet50(pool_v, train=False)
        self.v_clf_out = F.sigmoid(self.v_clf)
        self.print_freq = 100
        self.validation_weight = val_weight
        # val params
        self.acc = 0.0
        self.total_epochs = total_epochs
        self.max_iter = max_iter
        self.monitor_path = monitor_path

        # Optionally warm-start from a saved parameter file.
        if model_load_path is not None:
            _ = nn.load_parameters(model_load_path)
def data_distill(model, uniform_data_iterator, num_iter):
    """Generate distilled images by optimizing random inputs against the
    model's stored batch statistics (ZeroQ-style data-free distillation).

    For each batch from the uniform iterator, the image tensor itself is
    treated as the trainable parameter and optimized for `num_iter` Adam
    steps to minimize ``zeroq_loss`` over activations captured by the
    ``get_output`` forward hook.

    Args:
        model: callable network; assumed to populate the module-level
            ``outs`` / ``batch_stats`` globals via the forward hook —
            confirm against the hook's definition.
        uniform_data_iterator: iterator yielding initial random images.
        num_iter (int): optimization steps per batch.

    Returns:
        list: one optimized image array per input batch.
    """
    generated_img = []
    for _ in range(uniform_data_iterator.size //
                   uniform_data_iterator.batch_size):
        img, _ = uniform_data_iterator.next()
        # The image is the optimization variable (need_grad=True).
        dst_img = nn.Variable(img.shape, need_grad=True)
        dst_img.d = img
        img_params = OrderedDict()
        img_params['img'] = dst_img

        init_lr = 0.5
        solver = S.Adam(alpha=init_lr)
        solver.set_parameters(img_params)
        #scheduler = lr_scheduler.CosineScheduler(init_lr=0.5, max_iter=num_iter)
        scheduler = ReduceLROnPlateauScheduler(init_lr=init_lr,
                                               min_lr=1e-4,
                                               verbose=False,
                                               patience=100)
        # lr=0 solver never changes the model weights; it exists only so
        # the network gradients can be zeroed each step.
        dummy_solver = S.Sgd(lr=0)
        dummy_solver.set_parameters(nn.get_parameters())

        for it in tqdm(range(num_iter)):
            lr = scheduler.get_learning_rate()
            solver.set_learning_rate(lr)

            # Reset the capture buffers the forward hook appends to.
            global outs
            outs = []
            global batch_stats
            batch_stats = []

            y = model(denormalize(dst_img),
                      force_global_pooling=True,
                      training=False)  # denormalize to U(0, 255)
            y.forward(function_post_hook=get_output)
            assert len(outs) == len(batch_stats)
            loss = zeroq_loss(batch_stats, outs, dst_img)
            loss.forward()
            solver.zero_grad()
            dummy_solver.zero_grad()
            loss.backward()
            solver.weight_decay(1e-6)
            solver.update()

            # Plateau scheduler adapts the lr from the current loss.
            scheduler.update_lr(loss.d)

        generated_img.append(dst_img.d)

    return generated_img
Exemple #13
0
    def __init__(self, generator, args):
        """Configure latent-space projection for `generator` and run it.

        Args:
            generator: generator network exposing
                ``mapping_network_dim``; presumably produces images of
                ``img_size`` resolution — confirm against the generator.
            args: options forwarded unchanged to ``self.project``.
        """
        self.generator = generator

        # Optimizer for the latent code; base_lr is presumably applied
        # inside project() — confirm there.
        self.solver = S.Adam()
        self.base_lr = 0.1

        self.img_size = 1024
        self.n_latent = 10000
        self.num_iters = 500
        self.latent_dim = self.generator.mapping_network_dim
        # NOTE(review): mse_c / n_c look like loss coefficients used by
        # project() — verify their exact roles there.
        self.mse_c = 0.0
        self.n_c = 1e5

        # Perceptual (VGG LPIPS) distance used by the projection loss.
        self.lpips_distance = LPIPS(model='vgg')

        # The projection runs immediately on construction.
        self.project(args)
Exemple #14
0
def train(max_iter=60000):
    """Train the VAE on MNIST and return the monitor output directory.

    Args:
        max_iter (int): number of training iterations.

    Returns:
        str: path of the directory the monitors wrote to.
    """
    # Data providers: training and test splits.
    di_l = I.data_iterator_mnist(batch_size, True)
    di_t = I.data_iterator_mnist(batch_size, False)

    # Train and test loss graphs built over one shared input variable.
    shape_x = (1, 28, 28)
    shape_z = (50, )
    x = nn.Variable((batch_size, ) + shape_x)
    loss_l = I.vae(x, shape_z, test=False)
    loss_t = I.vae(x, shape_z, test=True)

    # Optimizer over all registered parameters.
    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors for training and validation series.
    path = cache_dir(os.path.join(I.name, "monitor"))
    monitor = M.Monitor(path)
    monitor_train_loss = M.MonitorSeries("train_loss", monitor, interval=600)
    monitor_val_loss = M.MonitorSeries("val_loss", monitor, interval=600)
    monitor_time = M.MonitorTimeElapsed("time", monitor, interval=600)

    for i in range(max_iter):
        solver.zero_grad()

        # One optimization step on a training batch.
        x.d, _ = di_l.next()
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(weight_decay)
        solver.update()

        # Evaluate the test loss on a test batch.
        x.d, _ = di_t.next()
        loss_t.forward(clear_no_need_grad=True)

        # Log both losses plus elapsed time.
        monitor_train_loss.add(i, loss_l.d.copy())
        monitor_val_loss.add(i, loss_t.d.copy())
        monitor_time.add(i)

    return path
Exemple #15
0
def main():
    """Train the MNIST model for one epoch on cuDNN, logging training
    loss and the epoch-end test error."""
    # Context
    ctx = get_extension_context("cudnn", device_id="0")
    nn.set_default_context(ctx)
    nn.auto_forward(False)
    # Inputs: training (x, t) and validation (vx, vt) batch variables.
    b, c, h, w = 64, 1, 28, 28
    x = nn.Variable([b, c, h, w])
    t = nn.Variable([b, 1])
    vx = nn.Variable([b, c, h, w])
    vt = nn.Variable([b, 1])
    # Model: one instance builds both the train and the test graph.
    model = Model()
    pred = model(x)
    loss = F.softmax_cross_entropy(pred, t)
    vpred = model(vx, test=True)
    verror = F.top_n_error(vpred, vt)
    # Solver over trainable parameters only (grad_only=True).
    solver = S.Adam()
    solver.set_parameters(model.get_parameters(grad_only=True))
    # Data Iterator
    tdi = data_iterator_mnist(b, train=True)
    vdi = data_iterator_mnist(b, train=False)
    # Monitor
    monitor = Monitor("tmp.monitor")
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Training loop
    for e in range(1):
        for j in range(tdi.size // b):
            # Global iteration index (used for the monitor x-axis).
            i = e * tdi.size // b + j
            x.d, t.d = tdi.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            monitor_loss.add(i, loss.d)
        # Mean test error across the full validation set.
        error = 0.0
        for _ in range(vdi.size // b):
            vx.d, vt.d = vdi.next()
            verror.forward(clear_buffer=True)
            error += verror.d
        error /= vdi.size // b
        monitor_verr.add(i, error)
Exemple #16
0
def create_model(net_name: str, batch_size=128, learning_rate=0.001):
    """Build train/validation graphs and an Adam solver for one of the
    named MNIST networks.

    Args:
        net_name (str): network identifier resolved by get_net_prediction.
        batch_size (int): samples per batch for both graphs.
        learning_rate (float): Adam learning rate.

    Returns:
        model namedtuple: (name, train net, val net, solver).
    """
    nn.clear_parameters()
    net = namedtuple("net", ("image", "label", "pred", "loss"))
    model = namedtuple("model", ("name", "train", "val", "solver"))
    mnist_cnn_prediction = get_net_prediction(net_name)
    # Binary networks consume inputs scaled into [0, 1].
    norm = 255 if net_name.startswith("binary") else 1

    def create_net(test, persistent):
        image = nn.Variable([batch_size, 1, 28, 28])
        label = nn.Variable([batch_size, 1])
        pred = mnist_cnn_prediction(image / norm, test=test)
        if persistent:
            pred.persistent = True
        loss = F.mean(F.softmax_cross_entropy(pred, label))
        return net(image, label, pred, loss)

    net_train = create_net(test=False, persistent=True)
    net_val = create_net(test=True, persistent=False)

    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    return model(net_name, net_train, net_val, solver)
Exemple #17
0
def train(max_iter=5000, learning_rate=0.001, weight_decay=0):
    """Train the network, validating every 100 iterations, then save the
    parameters.

    Args:
        max_iter (int): number of training iterations.
        learning_rate (float): Adam learning rate.
        weight_decay (float): decay applied before each update.

    Returns:
        str: directory where monitors and the final params were written.
    """
    train = create_net(False)
    test = create_net(True)

    # Create the solver.
    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create the monitors.
    path = cache_dir(os.path.join(I.name, "monitor"))
    monitor = M.Monitor(path)
    monitor_loss_train = M.MonitorSeries("training_loss", monitor, interval=100)
    monitor_time = M.MonitorTimeElapsed("time", monitor, interval=100)
    monitor_loss_val = M.MonitorSeries("val_loss", monitor, interval=100)

    # Run the training loop.
    for i in range(max_iter):
        # Validate every 100 iterations.
        if (i + 1) % 100 == 0:
            val_error = 0.0
            val_iter = 10
            for j in range(val_iter):
                test.image0.d, test.image1.d, test.label.d = test.data.next()
                test.loss.forward(clear_buffer=True)
                val_error += test.loss.d
            monitor_loss_val.add(i, val_error / val_iter)
        train.image0.d, train.image1.d, train.label.d = train.data.next()
        solver.zero_grad()
        train.loss.forward(clear_no_need_grad=True)
        train.loss.backward(clear_buffer=True)
        solver.weight_decay(weight_decay)
        solver.update()
        monitor_loss_train.add(i, train.loss.d.copy())
        monitor_time.add(i)

    # BUG FIX: the save and the return were indented inside the training
    # loop, so the function saved and returned after the very first
    # iteration. Both belong after the loop completes.
    nn.save_parameters(os.path.join(path, "params.h5"))
    return path
def classification_svd():
    """Train the SVD-slimmed LeNet on MNIST.

    Decomposes a pre-trained reference network's weights into the "slim"
    network, fine-tunes only the slim parameters, checkpoints whenever
    the validation error improves, and saves the final parameters.
    """
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction_slim

    # TRAIN
    reference = "reference"
    slim = "slim"
    rrate = 0.5  # reduction rate
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create the "slim" prediction graph.
    model_load_path = args.model_load_path
    pred = mnist_cnn_prediction(image, scope=slim, rrate=rrate, test=False)
    pred.persistent = True

    # Decompose the reference weights (SVD) into the slim parameters.
    decompose_network_and_set_params(model_load_path, reference, slim, rrate)
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create reference prediction graph.
    vpred = mnist_cnn_prediction(vimage, scope=slim, rrate=rrate, test=True)

    # Create Solver over the slim-scope parameters only.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(slim):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)
    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= args.val_iter
            monitor_verr.add(i, ve)
            # BUG FIX: the checkpoint test previously ran on *every*
            # iteration (re-saving identical parameters with a stale
            # error) and compared the un-averaged error sum against
            # best_ve; checkpoint only after a validation pass, using
            # the mean error.
            if ve < best_ve:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation pass and parameter dump.
    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Exemple #19
0
def main(args):
    """Semi-supervised CIFAR-10 training of a CNN with uncertainty-based
    stochastic regularization, plus knowledge transfer into a ResNet-23.

    Args:
        args: namespace providing device_id, batch_size, batch_size_eval,
            context (extension module name) and lambda_ (regularizer
            weight).
    """
    # Settings
    device_id = args.device_id
    batch_size = args.batch_size
    batch_size_eval = args.batch_size_eval
    n_l_train_data = 4000
    n_train_data = 50000
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    # BUG FIX: use floor division; true division yields a float, and both
    # range(n_iter) and the epoch-boundary modulo below require integers
    # on Python 3. (Unused locals `act` and `acc_prev` were dropped.)
    iter_epoch = n_train_data // batch_size
    n_iter = n_epoch * iter_epoch
    extension_module = args.context
    lambda_ = args.lambda_

    # Model
    ## supervised cnn
    batch_size, m, h, w = batch_size, 3, 32, 32
    ctx = extension_context(extension_module, device_id=device_id)
    x_l = nn.Variable((batch_size, m, h, w))
    y_l = nn.Variable((batch_size, 1))
    pred, log_var = cnn_model_003(ctx, "cnn", x_l)
    one = F.constant(1., log_var.shape)
    loss_ce = ce_loss(ctx, pred, y_l)
    reg_sigma = sigma_regularization(ctx, log_var, one)
    loss_supervised = loss_ce + er_loss(ctx, pred) + lambda_ * reg_sigma

    ## supervised resnet
    pred_res = cifar10_resnet23_prediction(ctx, "resnet", x_l)
    loss_res_ce = ce_loss(ctx, pred_res, y_l)
    loss_res_supervised = loss_res_ce

    ## stochastic regularization over two views of unlabeled data
    x_u0 = nn.Variable((batch_size, m, h, w))
    x_u0.persistent = True
    x_u1 = nn.Variable((batch_size, m, h, w))
    pred_x_u0, log_var0 = cnn_model_003(ctx, "cnn", x_u0)
    pred_x_u0.persistent = True
    pred_x_u1, log_var1 = cnn_model_003(ctx, "cnn", x_u1)
    loss_sr = sr_loss_with_uncertainty(ctx, pred_x_u0, pred_x_u1, log_var0,
                                       log_var1)
    reg_sigma0 = sigma_regularization(ctx, log_var0, one)
    reg_sigma1 = sigma_regularization(ctx, log_var1, one)
    reg_sigmas = sigmas_regularization(ctx, log_var0, log_var1)
    loss_unsupervised = loss_sr + er_loss(ctx, pred_x_u0) + er_loss(ctx, pred_x_u1) \
                        + lambda_ * (reg_sigma0 + reg_sigma1) + lambda_ * reg_sigmas

    ## knowledge transfer for resnet (cnn predictions act as teacher)
    pred_res_x_u0 = cifar10_resnet23_prediction(ctx, "resnet", x_u0)
    loss_res_unsupervised = kl_divergence(ctx, pred_res_x_u0, pred_x_u0,
                                          log_var0)

    ## evaluate
    # NOTE(review): this overwrites the batch_size_eval argument with
    # batch_size, as in the original — confirm that is intended.
    batch_size_eval, m, h, w = batch_size, 3, 32, 32
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    x_eval.persistent = True
    pred_eval, _ = cnn_model_003(ctx, "cnn", x_eval, test=True)
    pred_res_eval = cifar10_resnet23_prediction(ctx,
                                                "resnet",
                                                x_eval,
                                                test=True)

    # Solver (the original nested the same context_scope twice; one
    # scope is sufficient)
    with nn.context_scope(ctx):
        with nn.parameter_scope("cnn"):
            solver = S.Adam(alpha=learning_rate)
            solver.set_parameters(nn.get_parameters())
        with nn.parameter_scope("resnet"):
            solver_res = S.Adam(alpha=learning_rate)
            solver_res.set_parameters(nn.get_parameters())

    # Dataset
    ## separate dataset into labeled / unlabeled splits
    home = os.environ.get("HOME")
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(l_train_path,
                                    u_train_path,
                                    test_path,
                                    batch_size=batch_size,
                                    n_cls=n_cls,
                                    da=True,
                                    shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch()

        x_l.d, _, y_l.d = x_l0_data, x_l1_data, y_l_data
        x_u0.d, x_u1.d = x_u0_data, x_u1_data

        # Train cnn: gradients of both objectives accumulate between
        # zero_grad and the single update
        loss_supervised.forward(clear_no_need_grad=True)
        loss_unsupervised.forward(clear_no_need_grad=True)
        solver.zero_grad()
        loss_supervised.backward(clear_buffer=True)
        loss_unsupervised.backward(clear_buffer=True)
        solver.update()

        # Train resnet
        loss_res_supervised.forward(clear_no_need_grad=True)
        loss_res_unsupervised.forward(clear_no_need_grad=True)
        solver_res.zero_grad()
        loss_res_supervised.backward(clear_buffer=True)
        pred_x_u0.need_grad = False  # no need grad for teacher
        loss_res_unsupervised.backward(clear_buffer=True)
        solver_res.update()
        pred_x_u0.need_grad = True

        # Evaluate once per epoch
        if (i + 1) % iter_epoch == 0:
            # Get data and set it to the variables
            x_data, y_data = data_reader.get_test_batch()

            # Evaluation loop for cnn
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = get_test_data(x_data, k, batch_size_eval)
                label = get_test_data(y_data, k, batch_size_eval)
                pred_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_eval.d, label)
                iter_val += 1
            msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st, (1. - ve / iter_val) * 100)
            print(msg)

            # Evaluation loop for resnet
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = get_test_data(x_data, k, batch_size_eval)
                label = get_test_data(y_data, k, batch_size_eval)
                pred_res_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_res_eval.d, label)
                iter_val += 1
            msg = "Model:resnet,Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st, (1. - ve / iter_val) * 100)
            print(msg)

            st = time.time()
            epoch += 1
Exemple #20
0
def train(args):
    """Train a StarGAN (one generator + one discriminator) for multi-attribute
    image-to-image translation on CelebA.

    Builds the discriminator graph (adversarial + real-attribute classification
    + WGAN-GP gradient penalty) and the generator graph (adversarial +
    fake-attribute classification + cycle reconstruction), runs the adversarial
    training loop with linearly decayed learning rates, then saves the trained
    parameters, the training config as JSON, and translations of the test set.

    Args:
        args: parsed command-line arguments carrying model sizes, loss weights
            (lambda_cls / lambda_gp / lambda_rec), solver hyper-parameters,
            dataset paths and logging intervals.
    """
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim, num_downsample=args.num_downsample, num_upsample=args.num_upsample, repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator, image_size=args.image_size,
                                      conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    # Input image batch and its source/target attribute labels
    # (labels are broadcast-shaped as (B, c_dim, 1, 1)).
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake
    # (cuts the graph so the discriminator update does not backprop into the
    # generator; the generator is updated separately via x_fake.backward below).
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    # x_hat is a random convex combination of real and fake images; the
    # penalty pushes the discriminator gradient norm at x_hat towards 1.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images (cycle consistency: translate back to source domain).
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        # (a random permutation of the batch's own labels, as in the paper)
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        # Update the generator only once every n_critic discriminator steps.
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            # g_loss.backward also accumulates grads into the discriminator
            # parameters; zero them here so they are discarded (only
            # solver_gen.update() is called below).
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            # Propagate the gradient accumulated in x_fake_unlinked.grad
            # through the generator graph (grad=None reuses that buffer).
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    # NOTE(review): test_label_ndarray is not cast with
                    # astype(float) like the training labels, and rand_idx
                    # comes from the last training batch — confirm intended.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Learning rates get decayed
        # (linear decay to 0 over the second half of training).
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
Exemple #21
0
def train(data_iterator, monitor, config, comm, args):
    """Train a VQ-VAE, train its PixelCNN prior, or sample from the prior.

    Mode is selected by ``args``:
      * default                       -> train the VQ-VAE itself
      * ``args.pixelcnn_prior``       -> train a GatedPixelCNN prior over codes
      * ``args.sample_from_pixelcnn`` -> only generate samples (no training)

    Args:
        data_iterator: factory ``(config, comm, train=bool)`` returning a
            data iterator.
        monitor: nnabla ``Monitor`` used for logging (rank 0 only).
        config: configuration dict (train / val / dataset / model / prior).
        comm: communicator wrapper for (multi-)device training.
        args: parsed command-line flags.
    """
    monitor_train_loss, monitor_train_recon = None, None
    monitor_val_loss, monitor_val_recon = None, None
    # Only rank 0 creates monitors, to avoid duplicated logs in
    # distributed runs.
    if comm.rank == 0:
        monitor_train_loss = MonitorSeries(
            config['monitor']['train_loss'], monitor, interval=config['train']['logger_step_interval'])
        monitor_train_recon = MonitorImageTile(config['monitor']['train_recon'], monitor, interval=config['train']['logger_step_interval'],
                                               num_images=config['train']['batch_size'])

        monitor_val_loss = MonitorSeries(
            config['monitor']['val_loss'], monitor, interval=config['train']['logger_step_interval'])
        monitor_val_recon = MonitorImageTile(config['monitor']['val_recon'], monitor, interval=config['train']['logger_step_interval'],
                                             num_images=config['train']['batch_size'])

    model = VQVAE(config)

    # Sampling mode needs neither a solver nor data iterators.
    if not args.sample_from_pixelcnn:
        if config['train']['solver'] == 'adam':
            solver = S.Adam()
        else:
            # FIX: the nnabla momentum solver class is `S.Momentum`;
            # `S.momentum()` does not exist and raised AttributeError.
            solver = S.Momentum()
        solver.set_learning_rate(config['train']['learning_rate'])

        train_loader = data_iterator(config, comm, train=True)
        # Imagenet has no validation split configured here.
        if config['dataset']['name'] != 'imagenet':
            val_loader = data_iterator(config, comm, train=False)
        else:
            val_loader = None
    else:
        solver, train_loader, val_loader = None, None, None

    if not args.pixelcnn_prior:
        trainer = VQVAEtrainer(model, solver, train_loader, val_loader, monitor_train_loss,
                               monitor_train_recon, monitor_val_loss, monitor_val_recon, config, comm)
        num_epochs = config['train']['num_epochs']
    else:
        pixelcnn_model = GatedPixelCNN(config['prior'])
        trainer = TrainerPrior(model, pixelcnn_model, solver, train_loader, val_loader, monitor_train_loss,
                               monitor_train_recon, monitor_val_loss, monitor_val_recon, config, comm, eval=args.sample_from_pixelcnn)
        num_epochs = config['prior']['train']['num_epochs']

    # FIX: check existence of the checkpoint that will actually be loaded
    # (the prior's checkpoint when training/sampling the PixelCNN prior);
    # the original always tested the VQ-VAE checkpoint path.
    checkpoint_path = config['prior']['checkpoint'] if args.pixelcnn_prior else config['model']['checkpoint']
    if os.path.exists(checkpoint_path) and (args.load_checkpoint or args.sample_from_pixelcnn):
        trainer.load_checkpoint(checkpoint_path, msg='Parameters loaded from {}'.format(
            checkpoint_path), pixelcnn=args.pixelcnn_prior, load_solver=not args.sample_from_pixelcnn)

    if args.sample_from_pixelcnn:
        trainer.random_generate(
            args.sample_from_pixelcnn, args.sample_save_path)
        return

    for epoch in range(num_epochs):

        trainer.train(epoch)

        if epoch % config['val']['interval'] == 0 and val_loader is not None:
            trainer.validate(epoch)

        if comm.rank == 0:
            # FIX: use `num_epochs` (which may come from the prior config)
            # for the final-epoch save instead of always
            # config['train']['num_epochs'].
            if epoch % config['train']['save_param_step_interval'] == 0 or epoch == num_epochs - 1:
                trainer.save_checkpoint(
                    config['model']['saved_models_dir'], epoch, pixelcnn=args.pixelcnn_prior)
Exemple #22
0
def main(args):
    """Semi-supervised CIFAR-10 training of ``cnn_model_003``.

    Combines a supervised objective (cross-entropy + entropy regularization
    on labeled data) with an unsupervised one (stochastic regularization +
    entropy regularization on two augmentations of unlabeled data), and
    rescales the unsupervised gradients with a GradScaleContainer before
    each unsupervised update.

    Args:
        args: parsed command-line arguments (``device_id``, ``context``).
    """
    # Settings
    device_id = args.device_id
    batch_size = 100
    batch_size_eval = 100
    n_l_train_data = 4000   # number of labeled training images
    n_train_data = 50000    # total number of training images
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    # FIX: use floor division -- `range(n_iter)` below requires an int;
    # true division made these floats and raised TypeError on Python 3.
    iter_epoch = n_train_data // batch_size
    n_iter = n_epoch * iter_epoch
    extension_module = args.context

    # Model
    ## supervised graph
    m, h, w = 3, 32, 32  # CIFAR-10 image shape (channels, height, width)
    ctx = extension_context(extension_module, device_id=device_id)
    x_l = nn.Variable((batch_size, m, h, w))
    y_l = nn.Variable((batch_size, 1))
    pred = cnn_model_003(ctx, x_l)
    loss_ce = ce_loss(ctx, pred, y_l)
    loss_er = er_loss(ctx, pred)
    loss_supervised = loss_ce + loss_er

    ## stochastic regularization: consistency between two augmented views
    x_u0 = nn.Variable((batch_size, m, h, w))
    x_u1 = nn.Variable((batch_size, m, h, w))
    pred_x_u0 = cnn_model_003(ctx, x_u0)
    pred_x_u1 = cnn_model_003(ctx, x_u1)
    loss_sr = sr_loss(ctx, pred_x_u0, pred_x_u1)
    loss_er0 = er_loss(ctx, pred_x_u0)
    loss_er1 = er_loss(ctx, pred_x_u1)
    loss_unsupervised = loss_sr + loss_er0 + loss_er1

    ## evaluation graph
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    pred_eval = cnn_model_003(ctx, x_eval, test=True)

    # Solvers: two Adam instances over the same parameters so supervised
    # and unsupervised updates keep separate optimizer states.
    with nn.context_scope(ctx):
        solver_supervised = S.Adam(alpha=learning_rate)
        solver_supervised.set_parameters(nn.get_parameters())
        solver_unsupervised = S.Adam(alpha=learning_rate)
        solver_unsupervised.set_parameters(nn.get_parameters())

    # Gradient Scale Container
    gsc = GradScaleContainer(len(nn.get_parameters()))

    # Dataset
    ## separate dataset into labeled/unlabeled splits and save them
    home = os.path.expanduser("~")  # robust even when $HOME is unset
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(
        l_train_path,
        u_train_path,
        test_path,
        batch_size=batch_size,
        n_cls=n_cls,
        da=True,  #TODO: use F.image_augmentation
        shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, _ = data_reader.get_u_train_batch()

        x_l.d, _, y_l.d = x_l0_data, x_l1_data, y_l_data
        x_u0.d, x_u1.d = x_u0_data, x_u1_data

        # Train
        loss_supervised.forward(clear_no_need_grad=True)
        loss_unsupervised.forward(clear_no_need_grad=True)
        ## compute scales and update with grads of supervised loss
        solver_supervised.zero_grad()
        loss_supervised.backward(clear_buffer=True)
        gsc.set_scales_supervised_loss(nn.get_parameters())
        solver_supervised.update()
        ## compute scales and update with grads of unsupervised loss
        solver_unsupervised.zero_grad()
        loss_unsupervised.backward(clear_buffer=True)
        gsc.set_scales_unsupervised_loss(nn.get_parameters())
        gsc.scale_grad(ctx, nn.get_parameters())
        solver_unsupervised.update()

        # Evaluate once per epoch
        if (i + 1) % iter_epoch == 0:
            # Get data and set it to the variables
            x_data, y_data = data_reader.get_test_batch()

            # Evaluation loop
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = x_data[k:k + batch_size_eval, :]
                label = y_data[k:k + batch_size_eval, :]
                pred_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_eval.d, label)
                iter_val += 1
            msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st, (1. - ve / iter_val) * 100)
            print(msg)
            st = time.time()
            epoch += 1
Exemple #23
0
def train(args):
    """Train a WGAN-GP-style GAN on CIFAR-10.

    Builds generator/discriminator graphs with a gradient-penalty term,
    alternates ``n_critic`` discriminator steps per generator step, logs
    losses and image tiles, and periodically saves parameters.

    Args:
        args: parsed command-line arguments (latent size, model width,
            solver hyper-parameters, penalty weight ``lambda_``, logging
            and save intervals).
    """
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_

    # Model
    # generator loss
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(p_fake).apply(persistent=True)
    # discriminator loss
    # (a second discriminator call on x_fake; parameters are shared,
    # so this just builds a separate sub-graph for the critic update)
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)
    # gradient penalty: push the critic's gradient norm at random
    # real/fake interpolations towards 1.
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    p_rmix = discriminator(x_rmix, maps=maps)
    # NOTE(review): need_grad is set after the graph on x_rmix is already
    # built — presumably sufficient for nn.grad below; confirm.
    x_rmix.need_grad = True  # Enabling gradient computation for double backward
    grads = nn.grad([p_rmix], [x_rmix])
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum([F.mean((l - 1.0)**2.0) for l in l2norms])
    loss_dis += lambda_ * gp
    # generator with fixed value for test
    # (fixed z_test so saved tiles are comparable across iterations)
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)

    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries("Negative Critic Loss",
                                     monitor,
                                     interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=batch_size,
                                                interval=1,
                                                normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Data Iterator
    di = data_iterator_cifar10(batch_size, True)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator (n_critic critic steps per generator step)
        x_fake.need_grad = False  # no need backward to generator
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            # images are rescaled from [0, 255] to [-1, 1]
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need backward to generator
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()
        # Monitor
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Save
        if i % args.save_interval == 0:
            monitor_image_tile_train.add(i, x_fake)
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Last: refresh the fixed test images and save final parameters/tiles.
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
Exemple #24
0
def train():
    """Train cifar10_resnet23 on CIFAR-10 with SSL (structured sparsity)
    regularization.

    Validates every ``val_interval`` iterations, snapshots the parameters
    whenever the validation error improves, and saves the final parameters
    after training.
    """
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction
    else:
        # FIX: fail fast with a clear error instead of a NameError on
        # `model_prediction` later (consistent with the MNIST example).
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_train = 50000
    n_valid = 10000

    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    pred.persistent = True  # keep activations so pred.d is readable after backward
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # SSL Regularization (structured sparsity penalty on filters/channels)
    loss += ssl_regularization(nn.get_parameters(), args.filter_decay,
                               args.channel_decay)

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create predition graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    best_ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
            # Snapshot the best model so far (ve only changes here, so the
            # check belongs inside the validation branch).
            if ve < best_ve:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))
                best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation pass and parameter dump.
    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Exemple #25
0
def train():
    """Train a CNN classifier (LeNet or ResNet) on MNIST.

    Builds separate training and evaluation graphs, runs the optimization
    loop with periodic validation and parameter snapshots, and writes the
    final parameters under ``args.model_save_path``.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Select the computation context (CPU unless an extension is given).
    from nnabla.contrib.context import extension_context
    extension_module = 'cpu' if args.context is None else args.context
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Resolve the prediction-network builder from the requested net type.
    builders = {'lenet': mnist_lenet_prediction,
                'resnet': mnist_resnet_prediction}
    if args.net not in builders:
        raise ValueError("Unknown network type {}".format(args.net))
    build_prediction = builders[args.net]

    # --- Training graph ---
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    pred = build_prediction(image, test=False, aug=args.augment_train)
    pred.persistent = True  # keep activations so pred.d is readable after backward
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # --- Evaluation graph ---
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vpred = build_prediction(vimage, test=True, aug=args.augment_test)

    # --- Solver ---
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # --- Monitors ---
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # --- Data iterators (fixed RNG for reproducible training order) ---
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)

    def validation_error():
        """Mean categorical error over args.val_iter validation batches."""
        total = 0.0
        for _ in range(args.val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            total += categorical_error(vpred.d, vlabel.d)
        return total / args.val_iter

    # --- Training loop ---
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            monitor_verr.add(i, validation_error())
        if i % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # One optimization step: forward, backward, weight decay, update.
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        train_err = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, train_err)
        monitor_time.add(i)

    # Final validation pass and parameter dump.
    monitor_verr.add(i, validation_error())

    nn.save_parameters(os.path.join(
        args.model_save_path,
        '{}_params_{:06}.h5'.format(args.net, args.max_iter)))
Exemple #26
0
def train():
    """Train an OpenUnmix_CrossNet (X-UMX) audio source-separation model.

    Sets up a (possibly multi-device) context via ``CommunicatorWrapper``,
    builds training and validation graphs over MUSDB data sources, and runs
    an epoch loop with Adam, ReduceLROnPlateau learning-rate scheduling and
    early stopping.  Rank 0 logs monitors and saves the best model (lowest
    validation loss) as ``best_xumx.h5`` under ``args.output``.

    Raises:
        ValueError: if the installed nnabla is older than v1.19.0.
    """
    # Check NNabla version (core-engine memory efficiency landed in v1.19.0).
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    # comm.ctx carries the per-process device assignment for multi-GPU runs.
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    # NOTE(review): local name is misspelled ("traing"); kept as-is here
    # since a doc-only pass must not rename code.
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    # Only rank 0 prints and creates the output directory.
    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    # In distributed runs each process consumes its own shard of the data.
    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    # Weight decay is scaled by the process count — presumably to compensate
    # for gradient averaging in the all-reduce; TODO confirm.
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    # Limit the modeled spectrogram bins to the configured bandwidth.
    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    # Validation inputs: batch of 1, stereo mixture (2 ch) and an 8-channel
    # target (presumably 4 stereo stems — verify against the data source),
    # each args.valid_dur seconds long.
    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    # Frequency-domain loss (spectrogram MSE) + time-domain loss (SDR),
    # mixed with args.mcoef.
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    # Keep the loss buffer alive so loss.d is readable after backward with
    # clear_buffer=True.
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            # In distributed runs, gradients are all-reduced across devices
            # during backward via the communicator callback.
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            # Evaluate each validation track in fixed-length windows of
            # valid_dur seconds and average the per-window losses.
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while 1:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                # Stop when the next window would be shorter than dur, or
                # when the track length is an exact multiple of dur.
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # Adapt the learning rate on validation-loss plateaus, then check
        # the early-stopping criterion.
        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        # Only rank 0 logs monitors and writes checkpoints.
        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_traing_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
Exemple #27
0
def main(args):
    """Semi-supervised CIFAR-10 training with stochastic regularization.

    Combines a supervised loss (cross entropy + entropy regularization +
    sigma regularization on a learned log-variance head) over a small
    labeled split with an unsupervised consistency loss between two
    augmented views of unlabeled batches.  Periodically evaluates
    classification error and tracks the best value.

    Args:
        args: parsed CLI namespace providing device_id, batch sizes,
            context and lambda_ (regularization weight).
    """
    # Settings
    device_id = args.device_id
    batch_size = args.batch_size
    batch_size_eval = args.batch_size_eval
    # 4000 labeled images out of the 50000 CIFAR-10 training images.
    n_l_train_data = 4000
    n_train_data = 50000
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    act = F.relu
    # NOTE(review): true division — iter_epoch is a float and is later used
    # in a float modulo; `//` would be clearer when batch_size divides
    # n_train_data evenly. Left unchanged to preserve the exact cadence.
    iter_epoch = n_train_data / batch_size
    n_iter = int(n_epoch * iter_epoch)
    extension_module = args.context
    lambda_ = args.lambda_

    # Model
    ## supervised
    # CIFAR-10 images: 3 channels, 32x32 pixels.
    batch_size, m, h, w = batch_size, 3, 32, 32
    ctx = extension_context(extension_module, device_id=device_id)
    x_l = nn.Variable((batch_size, m, h, w))
    y_l = nn.Variable((batch_size, 1))
    # cnn_model_003 returns class predictions and a log-variance head.
    pred, log_var = cnn_model_003(ctx, x_l)
    one = F.constant(1., log_var.shape)
    loss_ce = ce_loss(ctx, pred, y_l)
    reg_sigma = sigma_regularization(ctx, log_var, one)
    # Supervised loss = cross entropy + entropy reg. + sigma reg.
    loss_supervised = loss_ce + er_loss(ctx, pred) + lambda_ * reg_sigma

    ## stochastic regularization
    # Two augmented views of the same unlabeled batch.
    x_u0 = nn.Variable((batch_size, m, h, w))
    x_u1 = nn.Variable((batch_size, m, h, w))
    pred_x_u0, log_var0 = cnn_model_003(ctx, x_u0)
    pred_x_u1, log_var1 = cnn_model_003(ctx, x_u1)
    loss_sr = sr_loss_with_uncertainty(ctx, pred_x_u0, pred_x_u1, log_var0,
                                       log_var1)
    reg_sigma0 = sigma_regularization(ctx, log_var0, one)
    reg_sigma1 = sigma_regularization(ctx, log_var1, one)
    reg_sigmas = sigmas_regularization(ctx, log_var0, log_var1)
    loss_unsupervised = loss_sr + er_loss(ctx, pred_x_u0) + er_loss(ctx, pred_x_u1) \
                        + lambda_ * (reg_sigma0 + reg_sigma1) + lambda_ * reg_sigmas
    ## evaluate
    batch_size_eval, m, h, w = batch_size, 3, 32, 32
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    pred_eval, _ = cnn_model_003(ctx, x_eval, test=True)

    # Solver
    with nn.context_scope(ctx):
        solver = S.Adam(alpha=learning_rate)
        solver.set_parameters(nn.get_parameters())

    # Dataset
    ## separate dataset
    home = os.environ.get("HOME")
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    #separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(l_train_path,
                                    u_train_path,
                                    test_path,
                                    batch_size=batch_size,
                                    n_cls=n_cls,
                                    da=True,
                                    shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    acc_prev = 0.  # NOTE(review): unused
    ve_best = 1.
    save_path_prev = ""  # NOTE(review): unused
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch()

        x_l.d, _, y_l.d = x_l0_data, x_l1_data, y_l_data
        x_u0.d, x_u1.d = x_u0_data, x_u1_data

        # Train: gradients from both losses accumulate before one update.
        loss_supervised.forward(clear_no_need_grad=True)
        loss_unsupervised.forward(clear_no_need_grad=True)
        solver.zero_grad()
        loss_supervised.backward(clear_buffer=True)
        loss_unsupervised.backward(clear_buffer=True)
        solver.update()

        # Evaluate (roughly once per epoch — see iter_epoch note above).
        if int((i + 1) % iter_epoch) == 0:
            # Get data and set it to the variables
            #x_data, y_data = data_reader.get_test_batch()
            # NOTE(review): evaluates on the (unlabeled) training split, not
            # the held-out test set — confirm this is intentional.
            u_train_data = data_reader.u_train_data
            x_data, y_data = u_train_data["train_x"].reshape(
                -1, 3, 32, 32) / 255.0, u_train_data["train_y"],

            # Evaluation loop
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = get_test_data(x_data, k, batch_size_eval)
                label = get_test_data(y_data, k, batch_size_eval)
                pred_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_eval.d, label)
                iter_val += 1
            ve /= iter_val
            msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st, (1. - ve) * 100)
            print(msg)
            # Track the best (lowest) validation error seen so far.
            if ve < ve_best:
                ve_best = ve
            st = time.time()
            epoch += 1
Exemple #28
0
def train(args):
    """Train a CycleGAN: two generators and two LSGAN discriminators.

    Generator ``g`` maps domain X -> Y and ``f`` maps Y -> X.  Training
    alternates a generator update (LSGAN + cycle-reconstruction loss, plus
    an optional identity loss) with discriminator updates that draw fake
    images from a history pool.  Learning rates decay linearly once per
    epoch, and parameters are checkpointed each epoch.

    Args:
        args: parsed CLI namespace (context, dataset, learning_rate,
            lambda_recon, lambda_idt, unpool, init_method, paths, ...).
    """
    # Settings
    # Batch of 1 RGB image at 256x256, as in the original CycleGAN setup.
    b, c, h, w = 1, 3, 256, 256
    beta1 = 0.5
    beta2 = 0.999
    pool_size = 50  # capacity of each fake-image history pool
    lambda_recon = args.lambda_recon
    lambda_idt = args.lambda_idt
    base_lr = args.learning_rate
    init_method = args.init_method

    # Context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Inputs
    x_raw = nn.Variable([b, c, h, w], need_grad=False)
    y_raw = nn.Variable([b, c, h, w], need_grad=False)
    x_real = image_augmentation(x_raw)
    y_real = image_augmentation(y_raw)
    # History variables are fed from the image pools when training the
    # discriminators (see the training loop below).
    x_history = nn.Variable([b, c, h, w])
    y_history = nn.Variable([b, c, h, w])
    x_real_test = nn.Variable([b, c, h, w], need_grad=False)
    y_real_test = nn.Variable([b, c, h, w], need_grad=False)

    # Models for training
    # Generate: g maps X -> Y, f maps Y -> X.
    y_fake = models.g(x_real, unpool=args.unpool, init_method=init_method)
    x_fake = models.f(y_real, unpool=args.unpool, init_method=init_method)
    # Keep the fake buffers alive so .d can be read for the history pools.
    y_fake.persistent, x_fake.persistent = True, True
    # Reconstruct (cycle: X -> Y -> X and Y -> X -> Y).
    x_recon = models.f(y_fake, unpool=args.unpool, init_method=init_method)
    y_recon = models.g(x_fake, unpool=args.unpool, init_method=init_method)
    # Discriminate
    d_y_fake = models.d_y(y_fake, init_method=init_method)
    d_x_fake = models.d_x(x_fake, init_method=init_method)
    d_y_real = models.d_y(y_real, init_method=init_method)
    d_x_real = models.d_x(x_real, init_method=init_method)
    d_y_history = models.d_y(y_history, init_method=init_method)
    d_x_history = models.d_x(x_history, init_method=init_method)

    # Models for test
    y_fake_test = models.g(
        x_real_test, unpool=args.unpool, init_method=init_method)
    x_fake_test = models.f(
        y_real_test, unpool=args.unpool, init_method=init_method)
    y_fake_test.persistent, x_fake_test.persistent = True, True
    # Reconstruct
    x_recon_test = models.f(
        y_fake_test, unpool=args.unpool, init_method=init_method)
    y_recon_test = models.g(
        x_fake_test, unpool=args.unpool, init_method=init_method)

    # Losses
    # Reconstruction Loss (cycle consistency)
    loss_recon = models.recon_loss(x_recon, x_real) \
        + models.recon_loss(y_recon, y_real)
    # Generator loss
    loss_gen = models.lsgan_loss(d_y_fake) \
        + models.lsgan_loss(d_x_fake) \
        + lambda_recon * loss_recon
    # Identity loss: each generator should leave images of its own target
    # domain unchanged.
    if lambda_idt != 0:
        logger.info("Identity loss was added.")
        # Identity
        y_idt = models.g(y_real, unpool=args.unpool, init_method=init_method)
        x_idt = models.f(x_real, unpool=args.unpool, init_method=init_method)
        loss_idt = models.recon_loss(x_idt, x_real) \
            + models.recon_loss(y_idt, y_real)
        loss_gen += lambda_recon * lambda_idt * loss_idt
    # Discriminator losses (fakes come from the pooled history).
    loss_dis_y = models.lsgan_loss(d_y_history, d_y_real)
    loss_dis_x = models.lsgan_loss(d_x_history, d_x_real)

    # Solvers — each network's parameters live in their own scope.
    solver_gen = S.Adam(base_lr, beta1, beta2)
    solver_dis_x = S.Adam(base_lr, beta1, beta2)
    solver_dis_y = S.Adam(base_lr, beta1, beta2)
    with nn.parameter_scope('generator'):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope('discriminator'):
        with nn.parameter_scope("x"):
            solver_dis_x.set_parameters(nn.get_parameters())
        with nn.parameter_scope("y"):
            solver_dis_y.set_parameters(nn.get_parameters())

    # Datasets
    rng = np.random.RandomState(313)
    ds_train_B = cycle_gan_data_source(
        args.dataset, train=True, domain="B", shuffle=True, rng=rng)
    ds_train_A = cycle_gan_data_source(
        args.dataset, train=True, domain="A", shuffle=True, rng=rng)
    ds_test_B = cycle_gan_data_source(
        args.dataset, train=False, domain="B", shuffle=False, rng=rng)
    ds_test_A = cycle_gan_data_source(
        args.dataset, train=False, domain="A", shuffle=False, rng=rng)
    di_train_B = cycle_gan_data_iterator(ds_train_B, args.batch_size)
    di_train_A = cycle_gan_data_iterator(ds_train_A, args.batch_size)
    di_test_B = cycle_gan_data_iterator(ds_test_B, args.batch_size)
    di_test_A = cycle_gan_data_iterator(ds_test_A, args.batch_size)

    # Monitors
    monitor = Monitor(args.monitor_path)

    def make_monitor(name):
        return MonitorSeries(name, monitor, interval=1)
    monitor_loss_gen = make_monitor('generator_loss')
    monitor_loss_dis_x = make_monitor('discriminator_B_domain_loss')
    monitor_loss_dis_y = make_monitor('discriminator_A_domain_loss')

    def make_monitor_image(name):
        # Maps images from [-1, 1] back to [0, 255] for visualization.
        return MonitorImage(name, monitor, interval=1,
                            normalize_method=lambda x: (x + 1.0) * 127.5)
    monitor_train_gx = make_monitor_image('fake_images_train_A')
    monitor_train_fy = make_monitor_image('fake_images_train_B')
    monitor_train_x_recon = make_monitor_image('fake_images_B_recon_train')
    monitor_train_y_recon = make_monitor_image('fake_images_A_recon_train')
    monitor_test_gx = make_monitor_image('fake_images_test_A')
    monitor_test_fy = make_monitor_image('fake_images_test_B')
    monitor_test_x_recon = make_monitor_image('fake_images_recon_test_B')
    monitor_test_y_recon = make_monitor_image('fake_images_recon_test_A')
    monitor_train_list = [
        (monitor_train_gx, y_fake),
        (monitor_train_fy, x_fake),
        (monitor_train_x_recon, x_recon),
        (monitor_train_y_recon, y_recon),
        (monitor_loss_gen, loss_gen),
        (monitor_loss_dis_x, loss_dis_x),
        (monitor_loss_dis_y, loss_dis_y),
    ]
    monitor_test_list = [
        (monitor_test_gx, y_fake_test),
        (monitor_test_fy, x_fake_test),
        (monitor_test_x_recon, x_recon_test),
        (monitor_test_y_recon, y_recon_test)]

    # ImagePool
    pool_x = ImagePool(pool_size)
    pool_y = ImagePool(pool_size)

    # Training loop
    epoch = 0
    n_images = np.max([ds_train_B.size, ds_train_A.size]
                      )  # num. images for each domain
    max_iter = args.max_epoch * n_images // args.batch_size
    for i in range(max_iter):
        # Validation (once per epoch's worth of iterations)
        if int((i+1) % (n_images // args.batch_size)) == 0:
            logger.info("Mode:Test,Epoch:{}".format(epoch))
            # Monitor for train
            # NOTE(review): the loop variable shadows the Monitor instance
            # created above; harmless as written since `monitor` is not used
            # afterwards, but rename if this function is extended.
            for monitor, v in monitor_train_list:
                monitor.add(i, v.d)
            # Use training graph since there are no test mode
            x_data, _ = di_test_B.next()
            y_data, _ = di_test_A.next()
            x_real_test.d = x_data
            y_real_test.d = y_data
            x_recon_test.forward()
            y_recon_test.forward()
            # Monitor for test
            for monitor, v in monitor_test_list:
                monitor.add(i, v.d)
            # Save model
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
            # Learning rate decay (linear, once per epoch, all solvers)
            for solver in [solver_gen, solver_dis_x, solver_dis_y]:
                linear_decay(solver, base_lr, epoch, args.max_epoch)
            epoch += 1

        # Get data
        x_data, _ = di_train_B.next()
        y_data, _ = di_train_A.next()
        x_raw.d = x_data
        y_raw.d = y_data

        # Train Generators
        loss_gen.forward(clear_no_need_grad=False)
        solver_gen.zero_grad()
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()

        # Insert the freshly generated fakes into the pools and feed the
        # (possibly older) pooled samples to the discriminator inputs.
        x_history.d = pool_x.insert_then_get(x_fake.d)
        y_history.d = pool_y.insert_then_get(y_fake.d)

        # Train Discriminator Y
        loss_dis_y.forward(clear_no_need_grad=False)
        solver_dis_y.zero_grad()
        loss_dis_y.backward(clear_buffer=True)
        solver_dis_y.update()

        # Train Discriminator X
        loss_dis_x.forward(clear_no_need_grad=False)
        solver_dis_x.zero_grad()
        loss_dis_x.backward(clear_buffer=True)
        solver_dis_x.update()
Exemple #29
0
        output = F.relu(PF.affine(m, output_mlp_size))
        if train:
            output = F.dropout(output, p=dropout_ratio)
    with nn.parameter_scope('output'):
        y = F.sigmoid(PF.affine(output, 1))

    accuracy = F.mean(F.equal(F.round(y), t))
    loss = F.mean(F.binary_cross_entropy(
        y, t)) + attention_penalty_coef * frobenius(
            F.batch_matmul(a, a, transpose_a=True) - batch_eye(batch_size, r))
    return x, t, accuracy, loss


# Create solver.
# Build the graph once so its parameters are registered in the global
# parameter scope, then hand them to the Adam solver.
x, t, accuracy, loss = build_self_attention_model(train=True)
solver = S.Adam()
solver.set_parameters(nn.get_parameters())

# Build a second graph for the Trainer; parameters are shared through the
# parameter scope. NOTE(review): rebuilding looks redundant — confirm the
# Trainer actually requires fresh Variables.
x, t, accuracy, loss = build_self_attention_model(train=True)
trainer = Trainer(inputs=[x, t],
                  loss=loss,
                  metrics={
                      'cross entropy': loss,
                      'accuracy': accuracy
                  },
                  solver=solver)
for epoch in range(max_epoch):
    x, t, accuracy, loss = build_self_attention_model(train=True)
    trainer.update_variables(inputs=[x, t],
                             loss=loss,
                             metrics={
Exemple #30
0
# Generator loss: push the discriminator's output on fakes toward label 1.
loss_gen = F.mean(
    F.sigmoid_cross_entropy(pred_fake, F.constant(1, pred_fake.shape)))
# Unlink the fake image from the generator graph so that discriminator
# gradients do not propagate back into the generator.
fake_dis = fake.get_unlinked_variable(need_grad=True)
fake_dis.need_grad = True  # TODO: Workaround until v1.0.2
pred_fake_dis = I.discriminator(fake_dis)
# Discriminator loss on fakes: target label 0.
loss_dis = F.mean(
    F.sigmoid_cross_entropy(pred_fake_dis, F.constant(0, pred_fake_dis.shape)))

# Set up the real-image path: the discriminator should output 1 on reals.
x = nn.Variable([batch_size, 1, 28, 28])  # 1x28x28 — presumably MNIST; confirm
pred_real = I.discriminator(x)
loss_dis += F.mean(
    F.sigmoid_cross_entropy(pred_real, F.constant(1, pred_real.shape)))

# Create the solvers (one per network, each bound to its own scope).
solver_gen = S.Adam(learning_rate, beta1=0.5)
solver_dis = S.Adam(learning_rate, beta1=0.5)
with nn.parameter_scope("gen"):
    solver_gen.set_parameters(nn.get_parameters())
with nn.parameter_scope("dis"):
    solver_dis.set_parameters(nn.get_parameters())

# A quick look at how parameter scopes behave.
print(len(nn.get_parameters()))
with nn.parameter_scope("gen"):
    print(len(nn.get_parameters()))
# Inside a parameter scope, the parameters returned by `get_parameters()`
# are filtered down to that scope.

# Set up the monitor output directory.
path = cache_dir(os.path.join(I.name, "monitor"))