Пример #1
0
def main(argv):
  """Train a PPO ActorCritic agent on the Atari game named in FLAGS.config.

  Args:
    argv: command-line arguments supplied by the app runner; not used.
  """
  del argv  # Unused.
  # Hide every GPU from TensorFlow so that it does not allocate GPU memory.
  tf.config.experimental.set_visible_devices([], 'GPU')
  cfg = FLAGS.config
  env_name = cfg.game + 'NoFrameskip-v4'
  n_actions = env_utils.get_num_actions(env_name)
  print(f'Playing {env_name} with {n_actions} actions')
  ppo_lib.train(models.ActorCritic(num_outputs=n_actions), cfg, FLAGS.workdir)
Пример #2
0
def main(argv):
    """Build an ActorCritic, create its optimizer and run PPO training.

    Args:
      argv: command-line arguments supplied by the app runner; not used.
    """
    # Hide every GPU from TensorFlow so that it does not allocate GPU memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
    cfg = FLAGS.config
    env_name = cfg.game + 'NoFrameskip-v4'
    n_actions = env_utils.get_num_actions(env_name)
    print(f'Playing {env_name} with {n_actions} actions')
    module = models.ActorCritic(num_outputs=n_actions)
    # Derive the parameter-initialisation key from a fixed root key.
    _, init_key = jax.random.split(jax.random.PRNGKey(0))
    params = models.get_initial_params(init_key, module)
    opt = models.create_optimizer(params, cfg.learning_rate)
    ppo_lib.train(module, opt, cfg, FLAGS.logdir)
Пример #3
0
 def test_model(self):
     """policy_action yields values of shape (batch, 1) and normalised log-probs."""
     num_outputs = self.choose_random_outputs()
     module = models.ActorCritic(num_outputs=num_outputs)
     params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), module)
     batch, obs_shape = 10, (84, 84, 4)
     observations = np.random.random(size=(batch, ) + obs_shape)
     log_probs, values = agent.policy_action(module.apply, params, observations)
     self.assertEqual(values.shape, (batch, 1))
     # Exponentiated log-probabilities of every row must add up to one.
     prob_totals = np.sum(np.exp(log_probs), axis=1)
     self.assertEqual(prob_totals.shape, (batch, ))
     np_testing.assert_allclose(prob_totals, np.ones((batch, )), atol=1e-6)
Пример #4
0
 def test_optimization_step(self):
     """A single ppo_lib.train_step on random data returns a flax Optimizer."""
     num_outputs = 4
     trn_data = self.generate_random_data(num_actions=num_outputs)
     clip_param = 0.1
     vf_coeff = 0.5
     entropy_coeff = 0.01
     # Learning rate is defined once and shared by optimizer and train_step.
     lr = 2.5e-4
     batch_size = 256
     _, subkey = jax.random.split(jax.random.PRNGKey(0))
     module = models.ActorCritic(num_outputs)
     initial_params = models.get_initial_params(subkey, module)
     optimizer = models.create_optimizer(initial_params, lr)
     optimizer, _ = ppo_lib.train_step(module, optimizer, trn_data,
                                       clip_param, vf_coeff, entropy_coeff,
                                       lr, batch_size)
     self.assertIsInstance(optimizer, flax.optim.base.Optimizer)
Пример #5
0
 def test_model(self):
     """Check optimizer creation and the outputs of agent.policy_action.

     Verifies that create_optimizer returns a flax Optimizer, that the value
     output has shape (batch, 1) and that each row of exp(log_probs) sums to 1.
     """
     _, subkey = jax.random.split(jax.random.PRNGKey(0))
     outputs = self.choose_random_outputs()
     module = models.ActorCritic(num_outputs=outputs)
     initial_params = models.get_initial_params(subkey, module)
     lr = 2.5e-4
     optimizer = models.create_optimizer(initial_params, lr)
     self.assertIsInstance(optimizer, flax.optim.base.Optimizer)
     test_batch_size, obs_shape = 10, (84, 84, 4)
     random_input = np.random.random(size=(test_batch_size, ) + obs_shape)
     log_probs, values = agent.policy_action(optimizer.target, module,
                                             random_input)
     self.assertEqual(values.shape, (test_batch_size, 1))
     # Probabilities of each row must form a distribution (sum to one).
     sum_probs = np.sum(np.exp(log_probs), axis=1)
     self.assertEqual(sum_probs.shape, (test_batch_size, ))
     np_testing.assert_allclose(sum_probs,
                                np.ones((test_batch_size, )),
                                atol=1e-6)
Пример #6
0
def main(argv):
    """Entry point: set up absl file logging, then train PPO on the game.

    Args:
      argv: command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: if extra positional arguments are given.
    """
    if argv[1:]:
        raise app.UsageError('Too many command-line arguments.')

    # Write absl logs into the work directory; stderr threshold set to info.
    FLAGS.log_dir = FLAGS.workdir
    FLAGS.stderrthreshold = 'info'
    logging.get_absl_handler().start_logging_to_file()

    # Hide every GPU from TensorFlow so that it does not allocate GPU memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
    cfg = FLAGS.config
    env_name = cfg.game + 'NoFrameskip-v4'
    n_actions = env_utils.get_num_actions(env_name)
    print(f'Playing {env_name} with {n_actions} actions')
    module = models.ActorCritic(num_outputs=n_actions)
    _, init_key = jax.random.split(jax.random.PRNGKey(0))
    opt = models.create_optimizer(
        models.get_initial_params(init_key, module), cfg.learning_rate)
    ppo_lib.train(module, opt, cfg, FLAGS.workdir)
Пример #7
0
 def test_optimization_step(self):
     """One ppo_lib.train_step on random data returns a flax TrainState."""
     n_actions = 4
     trn_data = self.generate_random_data(num_actions=n_actions)
     hyper = dict(clip_param=0.1, vf_coeff=0.5, entropy_coeff=0.01)
     batch_size = 256
     module = models.ActorCritic(n_actions)
     params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), module)
     config = ml_collections.ConfigDict(
         dict(learning_rate=2.5e-4, decaying_lr_and_clip_param=True))
     # 1000 is passed straight through to create_train_state (presumably the
     # number of steps used by the lr/clip schedule — confirm in ppo_lib).
     state = ppo_lib.create_train_state(params, module, config, 1000)
     state, _ = ppo_lib.train_step(state, trn_data, batch_size, **hyper)
     self.assertIsInstance(state, train_state.TrainState)
def _draw_task(folders, num_query, device):
    """Sample one few-shot task and return its support and query tensors.

    Builds a ``tg.MiniImagenetTask`` with ``CLASS_NUM`` classes,
    ``SAMPLE_NUM_PER_CLASS`` support images and ``num_query`` query images per
    class, then takes one batch from each split.

    Args:
        folders: class folders to draw the task from.
        num_query: number of query images per class.
        device: torch device the tensors are moved to.

    Returns:
        ``(samples, sample_labels, queries, query_labels)`` on ``device``.
    """
    task = tg.MiniImagenetTask(folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                               num_query)
    sample_dataloader = tg.get_mini_imagenet_data_loader(
        task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train",
        shuffle=False)
    query_dataloader = tg.get_mini_imagenet_data_loader(
        task, num_per_class=num_query, split="test", shuffle=True)
    samples, sample_labels = next(iter(sample_dataloader))
    queries, query_labels = next(iter(query_dataloader))
    return (samples.to(device), sample_labels.to(device),
            queries.to(device), query_labels.to(device))


def _relation_pairs(feature_encoder, samples, queries, num_query):
    """Concatenate every query embedding with every per-class support embedding.

    Support embeddings are summed over the shots of each class. Each query
    embedding is then concatenated channel-wise with each class's summed
    embedding.

    Args:
        feature_encoder: network mapping images to feature maps.
        samples: support images; reshaped below to
            ``(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)``.
        queries: query images; the concatenation requires exactly
            ``num_query * CLASS_NUM`` of them.
        num_query: number of query images per class.

    Returns:
        Tensor viewed as ``(-1, FEATURE_DIM * 2, 19, 19)``, one row per
        (query, class) pair.
    """
    sample_features = feature_encoder(samples).view(
        CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
    # Collapse the shots of each class into one summed support embedding.
    sample_features = torch.sum(sample_features, 1).squeeze(1)
    query_features = feature_encoder(queries)

    sample_features_ext = sample_features.unsqueeze(0).repeat(
        num_query * CLASS_NUM, 1, 1, 1, 1)
    query_features_ext = torch.transpose(
        query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
    return torch.cat((sample_features_ext, query_features_ext),
                     2).view(-1, FEATURE_DIM * 2, 19, 19)


def _model_path(kind):
    """Return the checkpoint path for ``kind`` at the current N-way K-shot setting."""
    return ("./models/miniimagenet_" + kind + "_" + str(CLASS_NUM) + "way_" +
            str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")


def main():
    """Meta-train the feature encoder and actor-critic model on mini-ImageNet.

    Each episode runs ``META_BATCH_RANGE`` meta-batches.

    A meta-batch first takes ``INNER_BATCH_RANGE`` inner gradient steps to
    build fast weights. It then computes the A2C loss on ``META_ENV_LENGTH``
    fresh tasks. The mean loss over the meta-batches updates both networks.

    Every 500 episodes both networks are evaluated on meta-test tasks. They
    are checkpointed to ``./models/`` whenever test accuracy improves.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # * Step 1: init data folders
    print("init data folders")
    metatrain_character_folders, metatest_character_folders = (
        tg.mini_imagenet_folders())

    # * Step 2: init neural networks
    print("init neural networks")
    feature_encoder = models.CNNEncoder()
    model = models.ActorCritic(FEATURE_DIM, RELATION_DIM, CLASS_NUM)

    feature_encoder.train()
    model.train()
    feature_encoder.apply(models.weights_init)
    model.apply(models.weights_init)
    feature_encoder.to(device)
    model.to(device)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=10000,
                                       gamma=0.5)
    model_optim = torch.optim.Adam(model.parameters(), lr=2.5 * LEARNING_RATE)
    model_scheduler = StepLR(model_optim, step_size=10000, gamma=0.5)

    agent = a2cAgent.A2CAgent(GAMMA, ENTROPY_WEIGHT, CLASS_NUM, device)

    feature_encoder_path = _model_path("feature_encoder")
    actor_network_path = _model_path("actor_network")
    # map_location lets a checkpoint saved on a GPU machine load on CPU.
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")
    if os.path.exists(actor_network_path):
        model.load_state_dict(
            torch.load(actor_network_path, map_location=device))
        print("load actor network success")

    # * Step 3: build graph
    print("Training...")

    last_accuracy = 0.0
    loss_list = []
    # Query images per class. Training and testing use separate variables so
    # that evaluation cannot change the training task size.
    number_of_query_image = 15
    number_of_test_query_image = 10
    for episode in range(EPISODE):
        losses = []

        for _ in range(META_BATCH_RANGE):
            meta_env_states_list = []
            meta_env_labels_list = []
            model_fast_weight = OrderedDict(model.named_parameters())
            for _ in range(INNER_BATCH_RANGE):
                # * Generate environment
                env_states_list = []
                env_labels_list = []
                for _ in range(ENV_LENGTH):
                    samples, _, batches, batch_labels = _draw_task(
                        metatrain_character_folders, number_of_query_image,
                        device)
                    env_states_list.append(
                        _relation_pairs(feature_encoder, samples, batches,
                                        number_of_query_image))
                    env_labels_list.append(batch_labels)

                inner_env = a2cAgent.env(env_states_list, env_labels_list)
                inner_loss = agent.train(inner_env, model)
                inner_gradients = torch.autograd.grad(
                    inner_loss.mean(),
                    model_fast_weight.values(),
                    create_graph=True,
                    allow_unused=True)

                # Parameters that did not contribute get a None gradient;
                # treat it as zero so the SGD step leaves them unchanged.
                model_fast_weight = OrderedDict(
                    (name, param - INNER_LR * (0 if grad is None else grad))
                    for (name, param), grad in zip(model_fast_weight.items(),
                                                   inner_gradients))

            # NOTE(review): this only sets a plain ``weight`` attribute on the
            # module. Unless models.ActorCritic reads ``self.weight`` in its
            # forward pass, the meta loss below is computed with the original
            # parameters and the fast weights have no effect. Confirm this.
            model.weight = model_fast_weight
            for _ in range(META_ENV_LENGTH):
                samples, _, batches, batch_labels = _draw_task(
                    metatrain_character_folders, number_of_query_image, device)
                meta_env_states_list.append(
                    _relation_pairs(feature_encoder, samples, batches,
                                    number_of_query_image))
                meta_env_labels_list.append(batch_labels)

            meta_env = a2cAgent.env(meta_env_states_list, meta_env_labels_list)
            losses.append(agent.train(meta_env, model))

        feature_encoder_optim.zero_grad()
        model_optim.zero_grad()

        meta_batch_loss = torch.stack(losses).mean()
        meta_batch_loss.backward()

        # Clip only after backward(). Before it the gradients have just been
        # zeroed, so clipping would have no effect.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

        feature_encoder_optim.step()
        model_optim.step()

        feature_encoder_scheduler.step()
        model_scheduler.step()

        if (episode + 1) % 100 == 0:
            mean_loss = meta_batch_loss.cpu().detach().numpy()
            print(f"episode : {episode+1}, meta_loss : {mean_loss:.4f}")
            loss_list.append(mean_loss)

        if (episode + 1) % 500 == 0:
            print("Testing...")
            # NOTE(review): both networks stay in train() mode during this
            # evaluation. Any BatchNorm/Dropout layers therefore behave as in
            # training. Confirm this is intended.
            total_reward = 0
            total_num_of_test_samples = 0
            for _ in range(TEST_EPISODE):
                sample_images, _, test_images, test_labels = _draw_task(
                    metatest_character_folders, number_of_test_query_image,
                    device)
                total_num_of_test_samples += len(test_labels)

                relation_pairs = _relation_pairs(feature_encoder,
                                                 sample_images, test_images,
                                                 number_of_test_query_image)
                test_env = a2cAgent.env([relation_pairs], [test_labels])
                total_reward += agent.test(test_env, model)

            test_accuracy = total_reward / (1.0 * total_num_of_test_samples)

            mean_loss = np.mean(loss_list)

            print(f'mean loss : {mean_loss}')
            print("test accuracy : ", test_accuracy)

            writer.add_scalar('1.loss', mean_loss, episode + 1)
            writer.add_scalar('4.test accuracy', test_accuracy, episode + 1)

            loss_list = []

            if test_accuracy > last_accuracy:
                # Save networks; create the checkpoint directory on first use.
                os.makedirs(os.path.dirname(feature_encoder_path),
                            exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(model.state_dict(), actor_network_path)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy