Example #1
    def plot(self, model_name, metric_name, metric_value):
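        # Lazily create a TensorboardLogger and a step Counter for this model.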
        if model_name not in self._writers:
            self._writers[model_name] = TensorboardLogger(
                os.path.join(self._log_dir, model_name))
        if model_name not in self._steps:
            self._steps[model_name] = Counter()

        self._writers[model_name].log_value(
            metric_name,
            metric_value,
            step=self._steps[model_name][metric_name])
        self._steps[model_name][metric_name] += 1

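        # Persist the per-model step counters to disk.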
        serialize(self._steps_path, self._steps)
Example #2
    def __init__(self, config: Configuration):
        self.config = config
        self.tensorboard_logger = TensorboardLogger(
            self.config.output_directory)
        self.buffer = PrioritizedBuffer(
            capacity=config.replay_capacity,
            epsilon=config.replay_min_priority,
            alpha=config.replay_prioritization_factor,
            max_priority=config.replay_max_priority,
        )
        self.beta = config.replay_importance_weight

        self.stats = LearnerStatistics(self.config, self.tensorboard_logger,
                                       self.buffer)
        learner_address = f"{config.learner_ip_address}:{config.starting_port}"
        self._connect_sockets(learner_address)
Example #3
 def __init__(self, opt):
     self.exp_name = opt['name']
     self.use_tb_logger = opt['use_tb_logger']
     self.opt = opt['logger']
     self.log_dir = opt['path']['log']
     # loss log file
     self.loss_log_path = os.path.join(self.log_dir, 'loss_log.txt')
     with open(self.loss_log_path, "a") as log_file:
         log_file.write('=============== Time: ' + get_timestamp() + ' =============\n')
         log_file.write('================ Training Losses ================\n')
     # val results log file
     self.val_log_path = os.path.join(self.log_dir, 'val_log.txt')
     with open(self.val_log_path, "a") as log_file:
         log_file.write('================ Time: ' + get_timestamp() + ' ===============\n')
         log_file.write('================ Validation Results ================\n')
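     # Create the TensorBoard logger only when enabled and not running a debug experiment.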
     if self.use_tb_logger and 'debug' not in self.exp_name:
         from tensorboard_logger import Logger as TensorboardLogger
         self.tb_logger = TensorboardLogger('../tb_logger/' + self.exp_name)
Example #4
def main():

    args = get_args()
    config = Configuration("./apex/config.json")

    tensorboard_logger = TensorboardLogger(config.output_directory,
                                           args.actor_index)
    actor = Actor(config, args.actor_index, args.starting_port,
                  tensorboard_logger)

    enemy_agents = []
    for _ in range(config.snakes - 1):
        enemy_agents.append(EnemyActor(actor))

    env = BattlesnakeEnvironment(
        config,
        enemy_agents=enemy_agents,
        output_directory=f"{config.output_directory}/actor-{args.actor_index}",
        actor_idx=args.actor_index,
        tensorboard_logger=tensorboard_logger,
    )

    wait_for_initial_parameters(actor)

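    # Collect experience indefinitely: act randomly for the first
    # config.random_initial_steps steps, then query the actor for actions.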
    while True:
        state = env.reset()
        terminal = False
        while not terminal:
            if env.stats.steps > config.random_initial_steps:
                action, greedy = actor.act(state)
            else:
                action = np.random.choice(3)
                greedy = False
            next_state, reward, terminal = env.step(action)
            actor.observe(
                Observation(state, action, reward, next_state,
                            config.discount_factor, greedy))
            state = next_state
        if env.stats.episodes % config.parameter_update_interval == 0:
            actor.update_parameters()
        if env.stats.episodes % config.report_interval == 0:
            env.stats.report()
        if env.stats.episodes % config.render_interval == 0:
            env.render()
Example #5
 def __init__(self, opt, tb_logger_suffix=''):
     self.exp_name = opt['name']
     self.use_tb_logger = opt['use_tb_logger']
     self.opt = opt['logger']
     self.log_dir = opt['path']['log']
     if not os.path.isdir(self.log_dir):
         os.mkdir(self.log_dir)
     # loss log file
     self.loss_log_path = os.path.join(self.log_dir, 'loss_log.txt')
     with open(self.loss_log_path, 'a') as log_file:
         log_file.write('=============== Time: ' + get_timestamp() +
                        ' =============\n')
         log_file.write(
             '================ Training Losses ================\n')
     # val results log file
     self.val_log_path = os.path.join(self.log_dir, 'val_log.txt')
     with open(self.val_log_path, 'a') as log_file:
         log_file.write('================ Time: ' + get_timestamp() +
                        ' ===============\n')
         log_file.write(
             '================ Validation Results ================\n')
     if self.use_tb_logger:  # and 'debug' not in self.exp_name:
         from tensorboard_logger import Logger as TensorboardLogger
         logger_dir_num = 0
         tb_logger_dir = self.log_dir.replace('experiments', 'logs')
         if not os.path.isdir(tb_logger_dir):
             os.mkdir(tb_logger_dir)
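         # Number the TensorBoard run directory one past the highest existing index.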
         existing_dirs = sorted(
             [d.split('_')[0] for d in os.listdir(tb_logger_dir)
              if os.path.isdir(os.path.join(tb_logger_dir, d))],
             key=lambda x: int(x.split('_')[0]))
         if len(existing_dirs) > 0:
             logger_dir_num = int(existing_dirs[-1]) + 1
         self.tb_logger = TensorboardLogger(
             os.path.join(tb_logger_dir,
                          str(logger_dir_num) + tb_logger_suffix))
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', '-n', type=str)
    args = parser.parse_args()
    experiment_name = args.name

    HYPARAMS = load_json(
        './hyparams/nec_hyparams.json')[experiment_name]['hyparams']
    logger.debug('experiment_name: {} hyparams: {}'.format(
        experiment_name, HYPARAMS))
    # make checkpoint path
    experiment_logdir = 'experiments/{}'.format(experiment_name)
    if not os.path.exists(experiment_logdir):
        os.makedirs(experiment_logdir)

    # write to tensorboard
    tensorboard_logdir = '{}/tensorboard'.format(experiment_logdir)
    if not os.path.exists(tensorboard_logdir):
        os.mkdir(tensorboard_logdir)
    writer = TensorboardLogger(logdir=tensorboard_logdir)

    env = gym.make('CartPole-v0')
    agent = NECAgent(input_dim=env.observation_space.shape[0],
                     encode_dim=32,
                     hidden_dim=64,
                     output_dim=env.action_space.n,
                     capacity=HYPARAMS['capacity'],
                     buffer_size=HYPARAMS['buffer_size'],
                     epsilon_start=HYPARAMS['epsilon_start'],
                     epsilon_end=HYPARAMS['epsilon_end'],
                     decay_factor=HYPARAMS['decay_factor'],
                     lr=HYPARAMS['lr'],
                     p=HYPARAMS['p'],
                     similarity_threshold=HYPARAMS['similarity_threshold'],
                     alpha=HYPARAMS['alpha'])
    global_steps = 0
    for episode in range(HYPARAMS['episodes']):
        state = env.reset()
        counter = 0
        while True:
            n_steps_q = 0
            start_state = state
            # N-steps Q estimate
            for step in range(HYPARAMS['horizon']):
                state_tensor = torch.from_numpy(state).float().unsqueeze(0)
                action_tensor, value_tensor, encoded_state_tensor = agent.epsilon_greedy_infer(
                    state_tensor)
                if step == 0:
                    start_action = action_tensor.item()
                    start_encoded_state = encoded_state_tensor
                # env.render()
                if global_steps > HYPARAMS['warmup_steps']:
                    action = action_tensor.item()
                    agent.epsilon_decay()
                else:
                    action = env.action_space.sample()
                logger.debug(
                    'episode: {} global_steps: {} value: {} action: {} state: {} epsilon: {}'
                    .format(episode, global_steps, value_tensor.item(), action,
                            state, agent.epsilon))
                next_state, reward, done, info = env.step(action)
                counter += 1
                global_steps += 1
                writer.log_training_v2(global_steps, {
                    'train/value': value_tensor.item(),
                })
                n_steps_q += (HYPARAMS['gamma']**step) * reward
                if done:
                    break
                state = next_state
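            # Bootstrap the N-step return with the agent's target value estimate.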
            n_steps_q += (HYPARAMS['gamma']**HYPARAMS['horizon']
                          ) * agent.get_target_n_steps_q().item()
            writer.log_training_v2(global_steps, {
                'sampled/n_steps_q': n_steps_q,
            })
            logger.debug('sample n_steps_q: {}'.format(n_steps_q))
            # append to ReplayBuffer and DND
            agent.remember_to_replay_buffer(start_state, start_action,
                                            n_steps_q)
            agent.remember_to_dnd(start_encoded_state, start_action, n_steps_q)

            if global_steps / HYPARAMS['horizon'] > HYPARAMS['batch_size']:
                agent.replay(batch_size=HYPARAMS['batch_size'])
            if done:
                # update dnd
                writer.log_episode(episode + 1, counter)
                logger.info('episode done! episode: {} score: {}'.format(
                    episode, counter))
                logger.debug('dnd[0] len: {}'.format(len(agent.dnd_list[0])))
                logger.debug('dnd[1] len: {}'.format(len(agent.dnd_list[1])))
                break
Example #7
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    parser.add_argument('--log-dir', default=None,
                        help='Directory to output the TensorBoard event file (default: runs/<DATETIME>)')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    if args.log_dir is None:
        args.log_dir = os.path.join('runs', datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(log_dir=args.log_dir)

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(MLP(args.unit, 10))
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    # Write training log to TensorBoard log file
    trainer.extend(TensorboardLogger(writer))

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name',
                        '-n',
                        required=True,
                        type=str,
                        help='name of experiment')
    parser.add_argument('--render', action='store_true', help='render gym')
    args = parser.parse_args()

    experiment_name = args.name
    is_render = args.render

    hyparams = load_json(
        './hyparams/dqn_hyparams.json')[experiment_name]['hyparams']
    # make checkpoint path
    experiment_logdir = 'experiments/{}'.format(experiment_name)
    if not os.path.exists(experiment_logdir):
        os.makedirs(experiment_logdir)

    # hyparameters
    lr = hyparams['lr']
    buffer_size = hyparams['buffer_size']
    gamma = hyparams['gamma']
    epsilon_start = hyparams['epsilon_start']
    epsilon_end = hyparams['epsilon_end']
    decay_factor = hyparams['decay_factor']
    batch_size = hyparams['batch_size']
    replay_freq = hyparams['replay_freq']
    target_update_freq = hyparams['target_update_freq']
    episodes = hyparams['episodes']
    warmup_steps = hyparams['warmup_steps']
    # max_steps = 1e10
    logger.debug('experiment_name: {} hyparams: {}'.format(
        experiment_name, hyparams))

    # write to tensorboard
    tensorboard_logdir = '{}/tensorboard'.format(experiment_logdir)
    if not os.path.exists(tensorboard_logdir):
        os.mkdir(tensorboard_logdir)
    writer = TensorboardLogger(logdir=tensorboard_logdir)

    env = gym.make('CartPole-v0')
    env.reset()
    # logger.debug('observation_space.shape: {}'.format(env.observation_space.shape))
    agent = DQNAgent(buffer_size,
                     writer=writer,
                     input_dim=env.observation_space.shape[0],
                     output_dim=env.action_space.n,
                     gamma=gamma,
                     epsilon_start=epsilon_start,
                     epsilon_end=epsilon_end,
                     decay_factor=decay_factor)

    state, _, _, _ = env.step(
        env.action_space.sample())  # take a random action to start with
    writer.add_graph(
        agent.policy_network,
        torch.tensor([state],
                     dtype=torch.float32))  # add model graph to tensorboard
    # state, reward, done, info = env.step(env.action_space.sample()) # take a random action to start with
    # for i in range(50):
    #     agent.remember(state, reward, env.action_space.sample(), state, False)
    # for i in range(50):
    #     agent.remember(state, reward, env.action_space.sample(), state, True)
    # loss = agent.replay(batch_size=5)
    global_steps = 0
    for episode in range(episodes):
        score = 0.0
        total_loss = 0.0
        state = env.reset()  # get a fresh observation for the new episode
        logger.debug('env.reset() episode {} starts!'.format(episode))
        # update target_network
        if episode % target_update_freq == 0:
            # 1. test replay_bufer
            # logger.debug('step: {} number of samples in bufer: {} sample: {}'.format(step, len(agent.replay_buffer), agent.replay_buffer.get_batch(2)))
            agent.update_target_network()
        for step in count():
            if is_render:
                env.render()
            action_tensor, value_tensor = agent.epsilon_greedy_infer(
                torch.tensor([state], dtype=torch.float32))
            target_value_tensor = agent.evaluate_state(
                torch.tensor([state], dtype=torch.float32))  # temp: for debug
            next_state, reward, done, info = env.step(
                action_tensor.item())  # take a random action
            # action = env.action_space.sample()
            # next_state, reward, done, info = env.step(action) # take a random action
            # logger.debug('episode: {} state: {} reward: {} action: {} next_state: {} done: {}'.format(episode, state, reward, action, next_state, done))
            agent.remember(state, reward, action_tensor.item(), next_state,
                           done)
            # 2. test QNetwork
            # logger.debug('state_tensor: {} action_tensor: {} value_tensor: {}'.format(state_tensor, action_tensor, value_tensor))
            # logger.debug('state_tensor: {} action: {} value: {}'.format(state_tensor, action_tensor.item(), value_tensor.item()))
            # print('state: {} reward: {} action_tensor.item(): {} next_state: {} done: {}'.format(state, reward, action_tensor.item(), next_state, done))
            score += reward
            # experience replay
            if global_steps > max(
                    batch_size,
                    warmup_steps) and global_steps % replay_freq == 0:
                loss = agent.replay(batch_size)
                total_loss += loss
                logger.debug(
                    'episode: {} done: {} global_steps: {} loss: {}'.format(
                        episode, done, global_steps, loss))
                writer.log_training(global_steps, loss, agent.lr,
                                    value_tensor.item(),
                                    target_value_tensor.item())
            writer.add_scalar('epsilon', agent.epsilon, global_steps)  # FIXME

            # if global_steps > max(batch_size, warmup_steps) and global_steps % 1000:
            #     writer.log_linear_weights(global_steps, 'encoder.0.weight', agent.policy_network.get_weights()['encoder.0.weight'])
            agent.epsilon_decay()
            state = next_state  # update state manually
            global_steps += 1
            if done:
                logger.info('episode done! episode: {} score: {}'.format(
                    episode, score))
                writer.log_episode(episode, score, total_loss / (step + 1))
                # save checkpoints
                if global_steps > max(batch_size,
                                      warmup_steps) and episode % 100 == 0:
                    agent.save_checkpoint(experiment_logdir)
                break
                # logger.debug('state_tensor: {} action_tensor: {} value_tensor: {}'.format(state_tensor, action_tensor, value_tensor))
                # logger.debug('output: {} state_tensor: {} state: {}'.format(output, state_tensor, state))
                # agent.remember(state, reward, action, next_state, done)

    env.close()
Example #9
    def train_generator(self, current_loop_num):
        BaseLayer.set_model_parameter_requires_grad_all(self.generator, True)
        BaseLayer.set_model_parameter_requires_grad_all(
            self.discriminator, False)

        # train generator
        # TensorboardLogger.print_parameter(generator)
        for index in range(0, self.opt.generator_train_num):
            train_z = self.Tensor(
                np.random.normal(loc=0,
                                 scale=1,
                                 size=(self.opt.batch_size,
                                       self.opt.latent_dim)))
            fake_imgs, fake_dlatents_out = self.generator(train_z)
            fake_validity = self.discriminator(fake_imgs)

            prob_fake = F.sigmoid(fake_validity).mean()
            TensorboardLogger.write_scalar('prob_fake/generator', prob_fake)
            # print('{} prob_fake(generator): {}'.format(index, prob_fake))

            g_loss = self.generator_loss(fake_validity)
            self.optimizer_g.zero_grad()
            g_loss.backward()
            self.optimizer_g.step()

        run_g_reg = current_loop_num % self.opt.g_reg_interval == 0
        if run_g_reg:
            # Regularization step for the generator
            g_reg_maxcount = 4 if 4 < self.opt.generator_train_num else self.opt.generator_train_num
            for _ in range(0, g_reg_maxcount):
                z = self.Tensor(
                    np.random.normal(loc=0,
                                     scale=1,
                                     size=(self.opt.batch_size,
                                           self.opt.latent_dim)))
                pl_fake_imgs, pl_fake_dlatents_out = self.generator(z)
                g_reg, pl_length = self.generator_loss_path_reg(
                    pl_fake_imgs, pl_fake_dlatents_out)
                self.optimizer_g.zero_grad()
                g_reg.backward()
                self.optimizer_g.step()

            TensorboardLogger.write_scalar('loss/g_reg', g_reg)
            TensorboardLogger.write_scalar('loss/path_length', pl_length)
            TensorboardLogger.write_scalar(
                'loss/pl_mean_var',
                self.generator_loss_path_reg.pl_mean_var.mean())

        # Apply the exponentially moving-averaged weights to the generator used for inference
        Generator.apply_decay_parameters(self.generator,
                                         self.generator_predict,
                                         decay=self.decay)
        fake_imgs_predict, fake_dlatents_out_predict = self.generator_predict(
            train_z)
        fake_predict_validity = self.discriminator(fake_imgs_predict)
        prob_fake_predict = F.sigmoid(fake_predict_validity).mean()
        TensorboardLogger.write_scalar('prob_fake_predict/generator',
                                       prob_fake_predict)
        # print('prob_fake_predict(generator): {}'.format(prob_fake_predict))

        Generator.apply_decay_parameters(self.generator_predict,
                                         self.generator,
                                         decay=self.opt.reverse_decay)

        if current_loop_num % self.opt.save_metrics_interval == 0:
            TensorboardLogger.write_scalar('score/g_score',
                                           fake_validity.mean())
            TensorboardLogger.write_scalar('loss/g_loss', g_loss)
            TensorboardLogger.write_histogram('generator/fake_imgs', fake_imgs)
            TensorboardLogger.write_histogram('generator/fake_dlatents_out',
                                              fake_dlatents_out)
            TensorboardLogger.write_histogram('generator/fake_imgs_predict',
                                              fake_imgs_predict)
            TensorboardLogger.write_histogram(
                'generator/fake_dlatents_out_predict',
                fake_dlatents_out_predict)

        if current_loop_num % self.opt.save_images_tensorboard_interval == 0:
            # for index in range(fake_imgs.shape[0]):
            #     img = adjust_dynamic_range(fake_imgs[index].to('cpu').detach().numpy(), drange_in=[-1, 1], drange_out=[0, 255])
            #     TensorboardLogger.write_image('images/fake/{}'.format(index), img)

            for index in range(fake_imgs_predict.shape[0]):
                img = adjust_dynamic_range(
                    fake_imgs_predict[index].to('cpu').detach().numpy(),
                    drange_in=[-1, 1],
                    drange_out=[0, 255])
                TensorboardLogger.write_image(
                    'images/fake_predict/{}'.format(index), img)

        if current_loop_num % self.opt.save_images_interval == 0:
            # Save the generated images
            if not os.path.isdir(self.opt.results):
                os.makedirs(self.opt.results, exist_ok=True)
            # fake_imgs_val, fake_dlatents_out_val = generator(val_z)
            # save_image_grid(
            #     # fake_imgs_val.to('cpu').detach().numpy(),
            #     fake_imgs.to('cpu').detach().numpy(),
            #     os.path.join(self.opt.results, '{}_fake.png'.format(TensorboardLogger.global_step)),
            #     batch_size=self.opt.batch_size,
            #     drange=[-1, 1])

            # fake_imgs_predict_val, fake_dlatents_out_predict_val = generator_predict(val_z)
            save_image_grid(fake_imgs_predict.to('cpu').detach().numpy(),
                            os.path.join(
                                self.opt.results, '{}_fake_predict.png'.format(
                                    TensorboardLogger.global_step)),
                            batch_size=self.opt.batch_size,
                            drange=[-1, 1])

        return g_loss
Example #10
 def calculate_fid_score(self):
     fid_score = self.fid.get_score()
     TensorboardLogger.write_scalar('score/fid', fid_score)
Example #11
    def train_discriminator(self, current_loop_num):
        BaseLayer.set_model_parameter_requires_grad_all(self.generator, False)
        BaseLayer.set_model_parameter_requires_grad_all(
            self.discriminator, True)

        # train discriminator
        for index in range(0, self.opt.discriminator_train_num):
            # Fetch one batch of real images (the iterator is recreated each step).
            data_iterator = iter(self.dataloader)
            imgs = next(data_iterator)
            # imgs = TranformDynamicRange.fade_lod(x=imgs, lod=0.0)
            # imgs = TranformDynamicRange.upscale_lod(x=imgs, lod=0.0)
            real_imgs = Variable(imgs.type(self.Tensor), requires_grad=False)

            z = self.Tensor(
                np.random.normal(loc=0,
                                 scale=1,
                                 size=(self.opt.batch_size,
                                       self.opt.latent_dim)))
            fake_imgs, fake_dlatents_out = self.generator(z)

            real_validity = self.discriminator(real_imgs)
            prob_real = F.sigmoid(real_validity).mean()
            TensorboardLogger.write_scalar('prob_real/discriminator',
                                           prob_real)
            # print('{} prob_real(discriminator): {}'.format(index, prob_real))

            fake_validity = self.discriminator(fake_imgs)
            prob_fake = F.sigmoid(fake_validity).mean()
            TensorboardLogger.write_scalar('prob_fake/discriminator',
                                           prob_fake)
            # print('{} prob_fake(discriminator): {}'.format(index, prob_fake))

            d_loss = self.discriminator_loss(fake_validity, real_validity)
            self.optimizer_d.zero_grad()
            d_loss.backward()
            self.optimizer_d.step()

        run_d_reg = current_loop_num % self.opt.d_reg_interval == 0
        if run_d_reg:
            d_reg_maxcount = 4 if 4 < self.opt.discriminator_train_num else self.opt.discriminator_train_num
            for index in range(0, d_reg_maxcount):
                # Regularization step for the discriminator
                # z = self.Tensor(np.random.normal(loc=0, scale=1, size=(self.opt.batch_size, self.opt.latent_dim)))
                # fake_imgs, fake_dlatents_out = self.generator(z)
                # fake_validity = self.discriminator(fake_imgs)

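                # Gradients w.r.t. the real images are needed for the R1 penalty.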
                real_imgs.requires_grad = True
                real_validity = self.discriminator(real_imgs)

                d_reg = self.discriminator_loss_r1(real_validity, real_imgs)
                self.optimizer_d.zero_grad()
                d_reg.backward()
                self.optimizer_d.step()
            TensorboardLogger.writer.add_scalar(
                '{}/reg/d_reg'.format(TensorboardLogger.now), d_reg,
                TensorboardLogger.global_step)

        if current_loop_num % self.opt.save_metrics_interval == 0:
            TensorboardLogger.write_scalar('score/d_score',
                                           real_validity.mean())
            TensorboardLogger.write_scalar('loss/d_loss', d_loss)
            TensorboardLogger.write_histogram('real_imgs', real_imgs)

        return d_loss
Example #12
def train_lstm(
    model: LstmModel,
    criterion: torch.nn.modules.loss,
    optimizer: torch.optim,
    train_loader: torch.utils.data.dataloader.DataLoader,
    val_loader: torch.utils.data.dataloader.DataLoader,
    device: str,
    verbose: bool,
    n_epochs: int,
    kwargs_writer: Dict[str, str] = None,
) -> NoReturn:
    """Short summary.

    Parameters
    ----------
    model : LstmModel
        Description of parameter `model`.
    criterion : torch.nn.modules.loss
        Description of parameter `criterion`.
    optimizer : torch.optim
        Description of parameter `optimizer`.
    train_loader : torch.utils.data.dataloader.DataLoader
        Description of parameter `train_loader`.
    val_loader : torch.utils.data.dataloader.DataLoader
        Description of parameter `val_loader`.
    device : str
        Description of parameter `device`.
    verbose : bool
        Description of parameter `verbose`.
    n_epochs : int
        Description of parameter `n_epochs`.
    kwargs_writer : Dict[str, str]
        Description of parameter `kwargs_writer`.

    Returns
    -------
    NoReturn
        Description of returned object.

    """

    model = model.to(device)
    dict_loader = {"fit": train_loader, "val": val_loader}

    writer = TensorboardLogger(kwargs_writer)
    global_step_fit = 0
    glob_step_val = 0
    for epoch in tqdm_notebook(range(1, n_epochs + 1)):
        # monitor training loss
        fit_loss = 0.0

        ###################
        # train the model #
        ###################
        epoch_losses = {}
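        # Run a training ("fit") pass and a validation pass each epoch, logging every batch loss.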
        for phase in ["fit", "val"]:

            # Accumulate the loss over the whole phase rather than per batch.
            total_loss = 0.0
            for chunk in dict_loader[phase]:

                data = chunk["data"].to(device)
                target = chunk["target"].to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "fit"):

                    outputs = model(data)

                    loss = criterion(outputs, target)
                    if phase == "fit":
                        loss.backward()
                        optimizer.step()
                        writer.add(
                            fit_loss=loss.item(),
                            val_loss=None,
                            model_for_gradient=None,
                            step=global_step_fit,
                        )
                        global_step_fit += 1

                    else:
                        writer.add(
                            fit_loss=None,
                            val_loss=loss.item(),
                            model_for_gradient=None,
                            step=glob_step_val,
                        )
                        glob_step_val += 1

                total_loss += loss.item() * data.size(0)

            epoch_losses.update({
                f"{phase} loss":
                total_loss / len(dict_loader[phase].dataset)
            })
            # if phase == "fit":
            #     writer.add(fit_loss=None, val_loss=None, model_for_gradient=model)

        # print avg training statistics
        if verbose:
            print(
                f'Fit loss: {epoch_losses["fit loss"]:.4f} and Val loss: {epoch_losses["val loss"]:.4f}'
            )