Example no. 1
    def __init__(self, map_index, policy, value, oracle_network, writer,
                 write_op, num_epochs, gamma, plot_model, save_model,
                 save_name, num_episide, share_exp, oracle, share_exp_weight):

        self.map_index = map_index
        self.PGNetwork = policy
        self.VNetwork = value
        self.ZNetwork = oracle_network

        self.writer = writer
        self.write_op = write_op
        self.num_epochs = num_epochs
        self.gamma = gamma
        self.save_name = save_name
        self.plot_model = plot_model
        self.save_model = save_model

        self.num_episide = num_episide
        self.oracle = oracle
        self.share_exp = share_exp
        self.share_exp_weight = share_exp_weight

        self.env = Terrain(self.map_index)

        self.plot_figure = PlotFigure(self.save_name, self.env)

        self.rollout = Rollout(number_episode=self.num_episide,
                               map_index=self.map_index)
        self.epsilon = 0.2
Example no. 2
 def get_signals(self, action):
     from rollout import Rollout
     signal_action = {'run': 'rollout', 'revert': 'rollback'}[action]
     abort_signal = Rollout.get_signal(self.rollout_id, 'abort_%s' % signal_action)
     term_signal = Rollout.get_signal(self.rollout_id, 'term_%s' % signal_action)
     if not (abort_signal and term_signal):
         raise Exception('Cannot run: one or more signals don\'t exist: '
                 'abort: %s, term: %s' % (abort_signal, term_signal))
     return abort_signal, term_signal
Example no. 3
def train_generator_PG(gen, gen_opt, dis, train_iter, num_batches):
    """
    The generator is trained using policy gradients, using the reward from the discriminator.
    Training is done for num_batches batches.
    """
    global pg_count
    global best_advbleu
    pg_count += 1
    num_sentences = 0
    total_loss = 0
    rollout = Rollout(gen, update_learning_rate)
    for i, data in enumerate(train_iter):
        if i == num_batches:
            break
        src_data_wrap = data.source
        ans = data.answer[0]
        # tgt_data = data.target[0].permute(1, 0)
        passage = src_data_wrap[0].permute(1, 0)

        if CUDA:
            scr_data = data.source[0].to(device)  # lengths x batch_size
            scr_lengths = data.source[1].to(device)
            ans = ans.to(device)
            ans_p = ans.permute(1, 0)
            src_data_wrap = (scr_data, scr_lengths, ans)
            passage = passage.to(device)
            passage = (passage, ans_p)

        num_sentences += scr_data.size(1)
        with torch.no_grad():
            samples, _ = gen.sample(src_data_wrap)        # 64 batch_size works best
            rewards = rollout.get_reward(samples, passage, src_data_wrap, rollout_size, dis, src_rev, rev, train_ref, tgt_pad)

        inp, target = helpers.prepare_generator_batch(samples, gpu=CUDA)

        gen_opt.zero_grad()
        pg_loss = gen.batchPGLoss(src_data_wrap, inp, target, rewards)
        pg_loss.backward()
        gen_opt.step()
        total_loss += pg_loss.item()
        rollout.update_params()  # refresh the rollout's delayed copy of the generator weights

    gen.eval()
    # print("Set gen to {0} mode".format('train' if model.decoder.dropout.training else 'eval'))
    valid_bleu = evaluation.evalModel(gen, val_iter, pg_count, rev, src_special, tgt_special, tgt_ref, src_rev)
    print('Validation bleu-4 = %g' % (valid_bleu * 100))
    if valid_bleu > best_advbleu:
        best_advbleu = valid_bleu
        torch.save(gen.state_dict(), 'advparams.pkl')
        print('save model')
    # train_bleu = evaluation.evalModel(gen, train_iter)
    # print('training bleu = %g' % (train_bleu * 100))
    gen.train()

    print("\npg_loss on %d bactches : %.4f" %(i+1, total_loss/num_batches))
Example no. 4
    def expand(self):
        # TO DO : self.node.possible_moves(), node.expand(move)
        print("Expand node")
        for move in self.node.board.legal_moves():
            # Add child
            child = self.node.expand_child(move)
            self.nodeList.append(child)
            # Calculate prior
            #child.prior = 0 #CNN.predict(child.board)

        # ADD ROLLOUT
        roll = Rollout(self.node)
        value = roll.lance_rollout()
        print(f"Rollout value : {value}")
        self.backpropagate(self.node, value)
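`lance_rollout` is defined elsewhere; below is a minimal sketch of the random playout such an MCTS rollout usually performs. The board interface used here (`is_terminal()`, `legal_moves()`, `play(move)`, `result()`) is a hypothetical stand-in for whatever game API the node wraps.

import copy
import random

def random_playout(board):
    """Play uniformly random moves until the game ends and return the final result
    (sketch; the board methods are assumptions)."""
    board = copy.deepcopy(board)  # never mutate the tree's own board
    while not board.is_terminal():
        move = random.choice(list(board.legal_moves()))
        board.play(move)
    return board.result()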
Example no. 5
	def train_gan(self, backend):

		rollout = Rollout(self.generator, self.discriminator, self.update_rate)
		print('\nStart Adversarial Training......')
		gen_optim, dis_optim = torch.optim.Adam(self.generator.parameters(), self.lr), torch.optim.Adam(self.discriminator.parameters(), self.lr)
		dis_criterion = util.to_cuda(nn.BCEWithLogitsLoss(size_average=False))
		gen_criterion = util.to_cuda(nn.CrossEntropyLoss(size_average=False, reduce=True))

		for epoch in range(self.gan_epochs):

			start = time.time()
			for _ in range(1):
				samples = self.generator.sample(self.batch_size, self.sequence_len) # (batch_size, sequence_len)
				zeros = util.to_var(torch.zeros(self.batch_size, 1).long()) # (batch_size, 1)
				inputs = torch.cat([samples, zeros], dim=1)[:, :-1] # (batch_size, sequence_len)
				rewards = rollout.reward(samples, 16) # (batch_size, sequence_len)
				rewards = util.to_var(torch.from_numpy(rewards))
				logits = self.generator(inputs) # (None, vocab_size, sequence_len)
				pg_loss = self.pg_loss(logits, samples, rewards)
				gen_optim.zero_grad()
				pg_loss.backward()
				gen_optim.step()

			print('generator updated via policy gradient......')

			if epoch % 10 == 0:
				util.generate_samples(self.generator, self.batch_size, self.sequence_len, self.generate_sum, self.eval_file)
				eval_data = GenData(self.eval_file)
				eval_data_loader = DataLoader(eval_data, batch_size=self.batch_size, shuffle=True, num_workers=8)
				loss = self.eval_epoch(self.target_lstm, eval_data_loader, gen_criterion)
				print('epoch: [{0:d}], true loss: [{1:.4f}]'.format(epoch, loss))



			for _ in range(1):
				util.generate_samples(self.generator, self.batch_size, self.sequence_len, self.generate_sum, self.fake_file)
				dis_data = DisData(self.real_file, self.fake_file)
				dis_data_loader = DataLoader(dis_data, batch_size=self.batch_size, shuffle=True, num_workers=8)
				for _ in range(1):
					loss = self.train_epoch(self.discriminator, dis_data_loader, dis_criterion, dis_optim)

			print('discriminator updated via gan loss......')

			rollout.update_params()

			end = time.time()

			print('time: [{:.3f}s/epoch] in {}'.format(end - start, backend))
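The reward signal used above, `rollout.reward(samples, 16)`, is the SeqGAN-style Monte-Carlo rollout reward. A minimal sketch of that computation follows; it assumes a `generator.sample_from(prefix, seq_len)` method that completes a prefix and a discriminator returning the probability that a sequence is real. Both names are illustrative, not this repository's API.

import torch

def rollout_reward_sketch(generator, discriminator, samples, num_rollouts=16):
    """Monte-Carlo rollout rewards, one value per generated token (sketch).

    samples: (batch, seq_len) token ids sampled from the generator.
    Returns: (batch, seq_len) discriminator scores averaged over the rollouts.
    """
    batch_size, seq_len = samples.size()
    rewards = torch.zeros(batch_size, seq_len)
    with torch.no_grad():
        for t in range(1, seq_len + 1):
            for _ in range(num_rollouts):
                # complete the prefix s_{1:t}; the full sequence needs no completion
                completed = samples if t == seq_len else generator.sample_from(samples[:, :t], seq_len)
                rewards[:, t - 1] += discriminator(completed)  # P(real), shape (batch,)
    return rewards / num_rollouts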
Example no. 6
def train_ad(total_batch, g_steps, d_steps, k, generator_fake, rollout,
             start_token, discriminator, lr, generated_num, batch_size,
             positive_file, negative_file):
    print('Start adversarial training')
    for _ in range(total_batch):
        trian_g(g_steps, generator_fake, start_token, rollout, discriminator, lr)
        rollout = Rollout(generator_fake)
        train_d(d_steps, k, generator_fake, start_token, generated_num,
                batch_size, discriminator, positive_file, negative_file, lr)
Example no. 7
	def __init__(
			self,
			env,
			env_config,
			policy,
			value,
			oracle_network,
			share_exp,
			oracle):

		self.env_config = env_config
		self.PGNetwork = policy
		self.VNetwork = value
		self.ZNetwork = oracle_network
		self.rollout = Rollout(number_episode=args.rollouts, map_index=env_config.map_index)

		self.oracle = oracle
		self.share_exp = share_exp
		self.env = env
Example no. 8
    def __init__(self,
                 state_size,
                 action_size,
                 lr=1e-3,
                 gamma=0.99,
                 clipping_epsilon=0.1,
                 ppo_epochs=10,
                 minibatch_size=64,
                 rollout_length=1000,
                 gae_lambda=0.95):
        self.lr = lr
        self.clipping_epsilon = clipping_epsilon
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size
        self.rollout_length = rollout_length

        self.policy = PolicyNet(state_size, action_size)
        self.value_estimator = ValueNet(state_size)
        self.rollout = Rollout(gamma=gamma, gae_lambda=gae_lambda)
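The `Rollout(gamma=gamma, gae_lambda=gae_lambda)` buffer above presumably turns stored transitions into generalized advantage estimates. A minimal sketch of GAE(lambda) over one finished rollout is shown below; the function and argument names are illustrative, not this repository's API.

import torch

def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    """GAE(lambda) advantages and returns for one rollout (sketch).

    rewards, dones: tensors of length T
    values:         tensor of length T + 1 (last entry is the bootstrap value)
    """
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
    returns = advantages + values[:-1]
    return advantages, returns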
Example no. 9
    def __init__(self, map_index, policies, writer, write_op, num_task,
                 num_iters, num_episode, num_epochs, gamma, lamb, plot_model,
                 save_model, save_name, share_exp, use_laser, use_gae,
                 noise_argmax):

        self.map_index = map_index
        self.PGNetwork = policies

        self.writer = writer
        self.write_op = write_op
        self.use_gae = use_gae

        self.num_task = num_task
        self.num_iters = num_iters
        self.num_epochs = num_epochs
        self.num_episode = num_episode

        self.gamma = gamma
        self.lamb = lamb
        self.save_name = save_name
        self.plot_model = plot_model
        self.save_model = save_model

        self.share_exp = share_exp
        self.noise_argmax = noise_argmax

        self.env = Terrain(self.map_index, use_laser)

        assert self.num_task <= self.env.num_task

        self.plot_figure = PlotFigure(self.save_name, self.env, self.num_task)

        self.rollout = Rollout(
            num_task=self.num_task,
            num_episode=self.num_episode,
            num_iters=self.num_iters,
            map_index=self.map_index,
            use_laser=use_laser,
            noise_argmax=self.noise_argmax,
        )
Example no. 10
    def __init__(self, sess, env, a_space, s_space):
        self.env = env
        self.sess = sess
        self.a_space = a_space
        self.s_space = s_space
        self.lr = 2e-4
        self.train_episodes = 5000
        self.gamma = 0.99
        self.epsilon = 0.2
        self.horizen = 128
        self.batch = 32
        self.rollout = Rollout(self.batch)
        self.entropy_coff = 0.02
        self.value_coff = 1

        # input tf variable
        self._init_input()
        # Actor-Critic output
        self._init_net_out()
        # operation
        self._init_op()

        self.sess.run(tf.global_variables_initializer())
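Here `epsilon = 0.2` is the PPO clipping coefficient, and `_init_op()` presumably builds the clipped surrogate objective. A minimal TensorFlow sketch of that objective, with illustrative names only, is:

import tensorflow as tf

def ppo_clipped_surrogate(log_prob, old_log_prob, advantage, epsilon=0.2):
    """Clipped PPO surrogate loss over a batch of actions (sketch)."""
    ratio = tf.exp(log_prob - old_log_prob)
    unclipped = ratio * advantage
    clipped = tf.clip_by_value(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    # negated because optimizers minimize
    return -tf.reduce_mean(tf.minimum(unclipped, clipped))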
Example no. 11
    def __init__(
            self,
            corpus: Corpus,
            # mean: torch.FloatTensor = torch.zeros(1024),
            # std: torch.FloatTensor = torch.ones(1024),
            low: float = -1,
            high: float = +1,
            hidden_size: int = 100,
            cnn_output_size: int = 4096,
            input_encoding_size: int = 512,
            max_sentence_length: int = 18,
            num_layers: int = 1,
            dropout: float = 0):
        super().__init__()
        self.cnn_output_size = cnn_output_size
        self.input_encoding_size = input_encoding_size
        self.max_sentence_length = max_sentence_length
        self.embed = corpus
        self.dropout = dropout
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # mean.shape[0]
        # self.dist = Normal(Variable(mean), Variable(std))  # noise variable
        self.dist = Uniform(low, high)  # noise variable
        self.lstm = nn.LSTM(input_size=corpus.embed_size,
                            hidden_size=self.input_encoding_size,
                            num_layers=self.num_layers,
                            batch_first=True,
                            dropout=self.dropout)

        self.output_linear = nn.Linear(self.input_encoding_size,
                                       corpus.vocab_size)
        self.features_linear = nn.Sequential(
            # nn.Linear(cnn_output_size + len(mean), input_encoding_size),
            nn.Linear(cnn_output_size + self.hidden_size, input_encoding_size),
            nn.ReLU())
        self.rollout = Rollout(max_sentence_length, corpus, self)
Example no. 12
def main(batch_size):
    if batch_size is None:
        batch_size = 1
    x, vocabulary, reverse_vocab, sentence_lengths = read_sampleFile()
    if batch_size > len(x):
        batch_size = len(x)
    start_token = vocabulary['START']
    end_token = vocabulary['END']
    pad_token = vocabulary['PAD']
    ignored_tokens = [start_token, end_token, pad_token]
    vocab_size = len(vocabulary)
    
    generator = pretrain_generator(x, start_token=start_token, 
                    end_token=end_token,ignored_tokens=ignored_tokens,
                    sentence_lengths=sentence_lengths,batch_size=batch_size,
                    vocab_size=vocab_size)
    x_gen = generator.generate(start_token=start_token, ignored_tokens=ignored_tokens, 
                               batch_size=len(x))
    discriminator = train_discriminator_wrapper(x, x_gen, batch_size, vocab_size)
    rollout = Rollout(generator, r_update_rate=0.8)
    rollout.to(DEVICE)
    for total_batch in range(TOTAL_BATCH):
        print('batch: {}'.format(total_batch))
        for it in range(1):
            samples = generator.generate(start_token=start_token, 
                    ignored_tokens=ignored_tokens, batch_size=batch_size)
            # Take average of ROLLOUT_ITER times of rewards.
            #   The more times a [0,1] class (positive, real data) 
            #   is returned, the higher the reward. 
            rewards = getReward(samples, rollout, discriminator)
            (generator, y_prob_all, y_output_all) = train_generator(model=generator, x=samples, 
                    reward=rewards, iter_n_gen=1, batch_size=batch_size, sentence_lengths=sentence_lengths)
        
        rollout.update_params(generator)
        
        for iter_n_dis in range(DIS_NUM_EPOCH):
            print('iter_n_dis: {}'.format(iter_n_dis))
            x_gen = generator.generate(start_token=start_token, ignored_tokens=ignored_tokens, 
                               batch_size=len(x))
            discriminator = train_discriminator_wrapper(x, x_gen, batch_size,vocab_size)
    
    log = openLog('genTxt.txt')
    num = generator.generate(batch_size=batch_size)
    words_all = decode(num, reverse_vocab, log)
    log.close()
    print(words_all)
Example no. 13
def train_GAN(conf_data):
    """Training Process for GAN.
    
    Parameters
    ----------
    conf_data: dict
        Dictionary containing all parameters and objects.       

    Returns
    -------
    conf_data: dict
        Dictionary containing all parameters and objects.       

    """
    seq = conf_data['GAN_model']['seq']
    if seq == 1:
        pre_epoch_num = conf_data['generator']['pre_epoch_num']
        GENERATED_NUM = 10000
        EVAL_FILE = 'eval.data'
        POSITIVE_FILE = 'real.data'
        NEGATIVE_FILE = 'gene.data'
    temp = 1  # TODO: determines how many times the discriminator is updated; take this as a config value
    epochs = int(conf_data['GAN_model']['epochs'])
    if seq == 0:
        dataloader = conf_data['data_learn']
    mini_batch_size = int(conf_data['GAN_model']['mini_batch_size'])
    data_label = int(conf_data['GAN_model']['data_label'])
    cuda = conf_data['cuda']
    g_latent_dim = int(conf_data['generator']['latent_dim'])
    classes = int(conf_data['GAN_model']['classes'])

    w_loss = int(conf_data['GAN_model']['w_loss'])

    clip_value = float(conf_data['GAN_model']['clip_value'])
    n_critic = int(conf_data['GAN_model']['n_critic'])

    lambda_gp = int(conf_data['GAN_model']['lambda_gp'])

    log_file = open(conf_data['performance_log'] + "/log.txt", "w+")
    # Convert these to parameters of the config data
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    conf_data['Tensor'] = Tensor
    LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
    conf_data['LongTensor'] = LongTensor
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    conf_data['FloatTensor'] = FloatTensor

    conf_data['epochs'] = epochs

    #print ("Just before training")
    if seq == 1:  #TODO: Change back to 1
        target_lstm = TargetLSTM(conf_data['GAN_model']['vocab_size'],
                                 conf_data['generator']['embedding_dim'],
                                 conf_data['generator']['hidden_dim'],
                                 conf_data['cuda'])
        if cuda:
            target_lstm = target_lstm.cuda()
        conf_data['target_lstm'] = target_lstm
        gen_data_iter = GenDataIter('real.data', mini_batch_size)
        generator = conf_data['generator_model']
        discriminator = conf_data['discriminator_model']
        g_loss_func = conf_data['generator_loss']
        d_loss_func = conf_data['discriminator_loss']
        optimizer_D = conf_data['discriminator_optimizer']
        optimizer_G = conf_data['generator_optimizer']
        #print('Pretrain with MLE ...')
        for epoch in range(pre_epoch_num):  #TODO: Change the range
            loss = train_epoch(generator, gen_data_iter, g_loss_func,
                               optimizer_G, conf_data, 'g')
            print('Epoch [%d] Model Loss: %f' % (epoch, loss))
            generate_samples(generator, mini_batch_size, GENERATED_NUM,
                             EVAL_FILE, conf_data)
            eval_iter = GenDataIter(EVAL_FILE, mini_batch_size)
            loss = eval_epoch(target_lstm, eval_iter, g_loss_func, conf_data)
            print('Epoch [%d] True Loss: %f' % (epoch, loss))

        dis_criterion = d_loss_func
        dis_optimizer = optimizer_D
        # TODO: understand why the two lines below were there
        # if conf_data['cuda']:
        #     dis_criterion = dis_criterion.cuda()

        #print('Pretrain Discriminator ...')
        dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                    mini_batch_size)
        for epoch in range(5):  #TODO: change back 5
            generate_samples(generator, mini_batch_size, GENERATED_NUM,
                             NEGATIVE_FILE, conf_data)
            dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                        mini_batch_size)
            for _ in range(3):  #TODO: change back 3
                loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                                   dis_optimizer, conf_data, 'd')
                print('Epoch [%d], loss: %f' % (epoch, loss))
        conf_data['generator_model'] = generator
        conf_data['discriminator_model'] = discriminator
        torch.save(conf_data['generator_model'].state_dict(),
                   conf_data['save_model_path'] + '/Seq/' + 'pre_generator.pt')
        torch.save(
            conf_data['discriminator_model'].state_dict(),
            conf_data['save_model_path'] + '/Seq/' + 'pre_discriminator.pt')

        conf_data['rollout'] = Rollout(generator, 0.8)

    for epoch in range(epochs):
        conf_data['epoch'] = epoch
        if seq == 0:
            to_iter = dataloader
        elif seq == 1:  #TODO: Change this back to 1
            to_iter = [1]

        for i, iterator in enumerate(to_iter):
            optimizer_D = conf_data['discriminator_optimizer']
            optimizer_G = conf_data['generator_optimizer']

            generator = conf_data['generator_model']
            discriminator = conf_data['discriminator_model']

            g_loss_func = conf_data['generator_loss']
            d_loss_func = conf_data['discriminator_loss']

            # if aux = 1:

            #print ("Reached here --------------> ")
            conf_data['iterator'] = i
            if seq == 0:

                if data_label == 1:
                    imgs, labels = iterator
                else:
                    imgs = iterator
                # Adversarial ground truths
                valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0),
                                 requires_grad=False)
                conf_data['valid'] = valid
                fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0),
                                requires_grad=False)
                conf_data['fake'] = fake
                # Configure input
                real_imgs = Variable(imgs.type(Tensor))

                if data_label == 1:
                    labels = Variable(labels.type(LongTensor))
                # Sample noise as generator input
                z = Variable(
                    Tensor(
                        np.random.normal(0, 1, (imgs.shape[0], g_latent_dim))))
                if classes > 0:
                    gen_labels = Variable(
                        LongTensor(np.random.randint(0, classes,
                                                     imgs.shape[0])))
                    conf_data['gen_labels'] = gen_labels
            # elif seq == 1: #If yes seqGAN
            #     # samples = generator.sample(mini_batch_size,conf_data['generator']['sequece_length'])
            #     # zeros = torch.zeros((mini_batch_size,1)).type(LongTensor)
            #     # imgs = Variable(torch.cat([zeros,samples.data]),dim=1)[:,:-1].contiguous() #TODO: change imgs to inps all, to make more sense of the code
            #     # targets = Variable(sample.data).contiguous().view((-1,))
            #     # rewards = rollout.get_reward(sample,16,discriminator)
            #     # rewards = Variable(Tensor(rewards))
            #     # prob = generator.forward(inputs)
            #     # loss = gen_gan_loss(prob)
            #     pass
            #     #optimizer_G

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            if seq == 1:  #TODO change this back to 1
                dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                            mini_batch_size)
            for _ in range(temp):  # TODO: make the number of discriminator updates a parameter; the stored models are re-read here as well -- should that be removed?
                optimizer_D = conf_data['discriminator_optimizer']
                optimizer_G = conf_data['generator_optimizer']

                generator = conf_data['generator_model']
                discriminator = conf_data['discriminator_model']

                g_loss_func = conf_data['generator_loss']
                d_loss_func = conf_data['discriminator_loss']
                if classes <= 0:
                    #print ("Reached here 2 --------------> ")
                    if seq == 0:
                        gen_imgs = generator(z)
                        # Measure discriminator's ability to classify real from generated samples
                        #Real images
                        real_validity = discriminator(real_imgs)
                        #Fake images
                        fake_validity = discriminator(gen_imgs.detach())
                    if seq == 1:
                        generate_samples(generator, mini_batch_size,
                                         GENERATED_NUM, NEGATIVE_FILE,
                                         conf_data)
                        dis_data_iter = DisDataIter(POSITIVE_FILE,
                                                    NEGATIVE_FILE,
                                                    mini_batch_size)
                        loss = train_epoch(discriminator, dis_data_iter,
                                           d_loss_func, optimizer_D, conf_data,
                                           'd')
                        conf_data['d_loss'] = loss
                        #exit()

                else:
                    if seq == 0:
                        gen_imgs = generator(z, gen_labels)
                        real_validity = discriminator(real_imgs, labels)
                        fake_validity = discriminator(gen_imgs.detach(),
                                                      labels)

                if seq == 0:
                    conf_data['gen_imgs'] = gen_imgs
                if seq == 0:
                    if w_loss == 0:
                        real_loss = d_loss_func.loss(real_validity, valid)
                        fake_loss = d_loss_func.loss(fake_validity, fake)
                        d_loss = (real_loss + fake_loss) / 2
                    elif w_loss == 1:
                        d_loss = -d_loss_func.loss(real_validity,
                                                   valid) + d_loss_func.loss(
                                                       fake_validity, fake)
                        if lambda_gp > 0:
                            conf_data['real_data_sample'] = real_imgs.data
                            conf_data['fake_data_sample'] = gen_imgs.data
                            conf_data = compute_gradient_penalty(conf_data)
                            gradient_penalty = conf_data['gradient_penalty']
                            d_loss = d_loss + lambda_gp * gradient_penalty
                    conf_data['d_loss'] = d_loss
                    d_loss.backward()
                    optimizer_D.step()

                if clip_value > 0:
                    # Clip weights of discriminator
                    for p in discriminator.parameters():
                        p.data.clamp_(-clip_value, clip_value)

            # -----------------
            #  Train Generator
            # -----------------
            conf_data['generator_model'] = generator
            conf_data['discriminator_model'] = discriminator

            # The next 4 lines were recently added; they may have to be removed.
            conf_data['optimizer_G'] = optimizer_G
            conf_data['optimizer_D'] = optimizer_D
            conf_data['generator_loss'] = g_loss_func
            conf_data['discriminator_loss'] = d_loss_func
            if seq == 0:
                conf_data['noise'] = z

            if n_critic <= 0:
                conf_data = training_fucntion_generator(conf_data)
            elif n_critic > 0:
                # Train the generator every n_critic iterations
                if i % n_critic == 0:
                    conf_data = training_fucntion_generator(conf_data)
            #exit()

        # print ("------------------ Here (train_GAN.py)")

            if seq == 0:
                batches_done = epoch * len(dataloader) + i
                if batches_done % int(conf_data['sample_interval']) == 0:
                    if classes <= 0:
                        # print ("Here")
                        # print (type(gen_imgs.data[:25]))
                        # print (gen_imgs.data[:25].shape)
                        save_image(gen_imgs.data[:25],
                                   conf_data['result_path'] +
                                   '/%d.png' % batches_done,
                                   nrow=5,
                                   normalize=True)
                    elif classes > 0:
                        sample_image(10, batches_done, conf_data)
        if seq == 0:
            log_file.write("[Epoch %d/%d] [D loss: %f] [G loss: %f] \n" %
                           (epoch, epochs, conf_data['d_loss'].item(),
                            conf_data['g_loss'].item()))
        elif seq == 1:
            # print ("Done")
            log_file.write(
                "[Epoch %d/%d] [D loss: %f] [G loss: %f] \n" %
                (epoch, epochs, conf_data['d_loss'], conf_data['g_loss']))
    conf_data['generator_model'] = generator
    conf_data['discriminator_model'] = discriminator
    conf_data['log_file'] = log_file
    return conf_data
Example no. 14
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0

    # Build data loaders for the generator, testing and discriminator
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_Data_loader(config_dis.dis_batch_size)

    # Build generator and its rollout
    generator = Generator(config=config_gen)
    generator.build()
    rollout_gen = Rollout(config=config_gen)

    # Build target LSTM
    target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params)  # The oracle model

    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    # Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    var_pretrained = [v for v in tf.trainable_variables() if
                      'teller' in v.name]  # Using name 'teller' here to prevent name collision of target LSTM
    gradients, variables = zip(
        *pretrained_optimizer.compute_gradients(generator.pretrain_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(zip(gradients, variables))

    # Initialize all variables
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())

    # Initialize data loader of generator
    generate_samples(sess, target_lstm, config_train.batch_size, config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    # Start pretraining
    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator...')
    log.write('pre-training...\n')
    for epoch in tqdm(range(config_train.pretrained_epoch_num), desc='Pre-Training(Generator)'):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _, g_loss = sess.run([gen_pre_update, generator.pretrain_loss], feed_dict={generator.input_seqs_pre: batch,
                                                                                      generator.input_seqs_mask: np.ones_like(
                                                                                          batch)})
        if epoch % config_train.test_per_epoch == 0:
            generate_samples(sess, generator, config_train.batch_size, config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            # print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            loss_info = 'pre-train epoch ' + str(epoch) + ' test_loss ' + str(test_loss)
            tqdm.write(loss_info)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    for _ in tqdm(range(config_train.dis_update_time_pre), desc='Pre-Training(Discriminator)'):
        generate_samples(sess, generator, config_train.batch_size, config_train.generated_num,
                         config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file, config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    # Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize global variables of optimizer for adversarial training
    uninitialized_var = [e for e in tf.global_variables() if e not in tf.trainable_variables()]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    # Start adversarial training
    for total_batch in tqdm(range(config_train.total_batch), desc='Adversarial-Training'):
        for _ in tqdm(range(config_train.gen_update_time), desc='Adversarial Generate Update'):
            samples = sess.run(generator.sample_word_list_reshape)
            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            # calculate the reward at each specific step t by rolling out
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step, feed_dict=feed)
                rollout_list_stack = np.vstack(rollout_list)  # shape: #batch_size * #rollout_step, #sequence length
                reward_rollout_seq = sess.run(discriminator.ypred_for_auc,
                                              feed_dict={discriminator.input_x: rollout_list_stack,
                                                         discriminator.dropout_keep_prob: 1.0})
                reward_last_tok = sess.run(discriminator.ypred_for_auc, feed_dict={discriminator.input_x: samples,
                                                                                   discriminator.dropout_keep_prob: 1.0})
                reward_allseq = np.concatenate((reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(r, config_gen.gen_batch_size * config_gen.sequence_length,
                                                          config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv],
                                   feed_dict={generator.input_seqs_adv: samples,
                                              generator.rewards: rewards})
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            generate_samples(sess,
                             generator,
                             config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            loss_info = 'total_batch: ' + str(total_batch) + ' test_loss: ' + str(test_loss)
            tqdm.write(loss_info)
            log.write(buffer)

        for _ in tqdm(range(config_train.dis_update_time_adv), desc='Adversarial Discriminator Update'):
            generate_samples(sess, generator,
                             config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file, config_train.negative_file)

            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
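The nested indexing over `reward_allseq` above (the `reward_tmp` loop) extracts one reward per (sample, step) pair. A more direct, equivalent reshape is sketched below, assuming the scores are stacked step-major as in the loop above; the function name is illustrative.

import numpy as np

def rewards_from_scores(reward_allseq, batch_size, sequence_length):
    """Equivalent to the reward_tmp loop (sketch): the scores are stacked as
    sequence_length blocks of batch_size, so a reshape plus transpose yields
    one reward per (sample, step)."""
    scores = np.asarray(reward_allseq)  # shape: (sequence_length * batch_size,)
    return scores.reshape(sequence_length, batch_size).T  # (batch_size, sequence_length)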
Example no. 15
    def __init__(self, training_scene, training_objects, config, arguments):

        self.config = config
        self.arguments = arguments

        self.training_scene = training_scene
        self.training_objects = training_objects

        self.use_gae = arguments.get('use_gae')
        self.num_epochs = arguments.get('num_epochs')
        self.num_episodes = arguments.get('num_episodes')
        self.num_iters = arguments.get('num_iters')
        self.gamma = arguments.get('gamma')
        self.lamb = arguments.get('lamb')
        self.lr = arguments.get('lr')
        self.joint_loss = arguments.get('joint_loss')
        self.ec = arguments.get('ec')
        self.vc = arguments.get('vc')
        self.max_grad_norm = arguments.get('max_gradient_norm')
        self.dropout = arguments.get('dropout')
        self.decay = arguments.get('decay')
        self.reuse = arguments.get('share_latent')
        self.gpu_fraction = arguments.get('gpu_fraction')

        assert len(training_objects) == 2, "exactly 2 sharing agents are required; other counts are not supported yet."
        self.env = AI2ThorDumpEnv(training_scene, training_objects[0], config,
                                  arguments)

        sharing = self.env.h5_file["_".join(training_objects)][()].tolist()
        non_sharing = list(
            set(list(range(self.env.h5_file['locations'].shape[0]))) -
            set(sharing))

        self.sharing = dict(
            zip(sharing + non_sharing,
                [1] * len(sharing) + [0] * len(non_sharing)))

        self.rollouts = []
        for obj in training_objects:
            self.rollouts.append(
                Rollout(training_scene, obj, config, arguments))

        tf.reset_default_graph()

        self.PGNetworks = []
        for i in range(2):
            agent = A2C(name='A2C_' + str(i),
                        state_size=self.env.features.shape[1],
                        action_size=self.env.action_space,
                        history_size=arguments['history_size'],
                        embedding_size=-1 if arguments['mode'] != 2 else 300,
                        entropy_coeff=self.ec,
                        value_function_coeff=self.vc,
                        max_gradient_norm=self.max_grad_norm,
                        dropout=self.dropout,
                        joint_loss=self.joint_loss,
                        learning_rate=self.lr,
                        decay=self.decay,
                        reuse=bool(self.reuse))

            if self.decay:
                agent.set_lr_decay(
                    self.lr,
                    self.num_epochs * self.num_episodes * self.num_iters)

            print("\nInitialized network with {} trainable weights.".format(
                len(agent.find_trainable_variables('A2C_' + str(i), True))))
            self.PGNetworks.append(agent)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_fraction)

        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver()

        timer = "{}_{}_{}".format(str(datetime.now()).replace(" ", "-").replace(".", "-").replace(":", "-"), \
               training_scene, "_".join(training_objects))
        self.log_folder = os.path.join(arguments.get('logging'), timer)
        self.writer = tf.summary.FileWriter(self.log_folder)

        self.timer = timer

        test_name = training_scene
        for i in range(len(training_objects)):
            tf.summary.scalar(
                test_name + "/" + training_objects[i] + "/rewards",
                self.PGNetworks[i].mean_reward)
            tf.summary.scalar(
                test_name + "/" + training_objects[i] + "/success_rate",
                self.PGNetworks[i].success_rate)
            tf.summary.scalar(
                test_name + "/" + training_objects[i] + "/redundants",
                self.PGNetworks[i].mean_redundant)

        self.write_op = tf.summary.merge_all()
Example no. 16
class ConditionalGenerator(nn.Module):
    def __init__(
            self,
            corpus: Corpus,
            # mean: torch.FloatTensor = torch.zeros(1024),
            # std: torch.FloatTensor = torch.ones(1024),
            low: float = -1,
            high: float = +1,
            hidden_size: int = 100,
            cnn_output_size: int = 4096,
            input_encoding_size: int = 512,
            max_sentence_length: int = 18,
            num_layers: int = 1,
            dropout: float = 0):
        super().__init__()
        self.cnn_output_size = cnn_output_size
        self.input_encoding_size = input_encoding_size
        self.max_sentence_length = max_sentence_length
        self.embed = corpus
        self.dropout = dropout
        self.num_layers = num_layers
        self.hidden_size = hidden_size  # mean.shape[0]
        # self.dist = Normal(Variable(mean), Variable(std))  # noise variable
        self.dist = Uniform(low, high)  # noise variable
        self.lstm = nn.LSTM(input_size=corpus.embed_size,
                            hidden_size=self.input_encoding_size,
                            num_layers=self.num_layers,
                            batch_first=True,
                            dropout=self.dropout)

        self.output_linear = nn.Linear(self.input_encoding_size,
                                       corpus.vocab_size)
        self.features_linear = nn.Sequential(
            # nn.Linear(cnn_output_size + len(mean), input_encoding_size),
            nn.Linear(cnn_output_size + self.hidden_size, input_encoding_size),
            nn.ReLU())
        self.rollout = Rollout(max_sentence_length, corpus, self)

    def init_hidden(self, image_features):

        # zero "noise" vector: the deterministic counterpart of init_hidden_noise below
        z = torch.zeros(image_features.shape[0], self.hidden_size).cuda()

        # hidden of shape (num_layers * num_directions, batch, hidden_size)
        hidden = self.features_linear(
            torch.cat((image_features, z), 1).unsqueeze(0))

        # cell of shape (num_layers * num_directions, batch, hidden_size)
        cell = torch.zeros(1, image_features.shape[0],
                           self.input_encoding_size)

        return hidden.cuda(), cell.cuda()

    def forward(self, features, captions):
        states = self.init_hidden(features)
        hiddens, _ = self.lstm(captions, states)
        outputs = self.output_linear(hiddens[0])
        return outputs

    def init_hidden_noise(self, image_features):
        # z = torch.zeros(image_features.shape[0], self.hidden_size).cuda()
        z = self.dist.sample(
            (image_features.shape[0], self.hidden_size)).cuda()

        # hidden of shape (num_layers * num_directions, batch, hidden_size)
        hidden = self.features_linear(
            torch.cat((image_features, z), 1).unsqueeze(0))

        # cell of shape (num_layers * num_directions, batch, hidden_size)
        cell = Variable(
            torch.zeros(image_features.shape[0],
                        self.input_encoding_size).unsqueeze(0))

        return hidden.cuda(), cell.cuda()

    def forward_noise(self, features, captions):
        states = self.init_hidden_noise(features)
        hiddens, _ = self.lstm(captions, states)
        outputs = self.output_linear(hiddens[0])
        return outputs

    def reward_forward(self,
                       image_features,
                       evaluator,
                       monte_carlo_count=16,
                       steps=1):
        # self.lstm.flatten_parameters()
        batch_size = image_features.size(0)
        hidden = self.init_hidden_noise(image_features)
        # embed the start symbol
        inputs = self.embed.word_embeddings([self.embed.START_SYMBOL] *
                                            batch_size).unsqueeze(1).cuda()
        rewards = torch.zeros(batch_size, self.max_sentence_length - 1)
        # rewards[:, 0] = torch.ones(batch_size)

        props = torch.zeros(batch_size, self.max_sentence_length - 1)
        # props[:, 0] = torch.ones(batch_size)

        current_generated = inputs
        self.rollout.update(self)
        for i in range(self.max_sentence_length - 1):
            _, hidden = self.lstm(inputs, hidden)
            outputs = self.output_linear(hidden[0]).squeeze(0)
            outputs = F.softmax(outputs, -1)

            predicted = outputs.multinomial(1)
            prop = torch.gather(outputs, 1, predicted)
            props[:, i] = prop.view(-1)
            # m = Categorical(outputs)
            # predicted = m.sample()
            # props[:, i] = -m.log_prob(predicted)

            # embed the next inputs, unsqueeze is required cause of shape (batch_size, 1, embedding_size)
            inputs = self.embed.word_embeddings_from_indices(
                predicted.view(-1).cpu().data.numpy()).unsqueeze(1).cuda()
            current_generated = torch.cat([current_generated, inputs], dim=1)
            reward = self.rollout.reward2(current_generated, image_features,
                                          hidden, monte_carlo_count, evaluator,
                                          steps)
            rewards[:, i] = reward.view(-1)
        return rewards, props

    def sample(self, image_features, return_sentence=True):
        batch_size = image_features.size(0)

        # init the result with zeros and lstm states
        result = []
        hidden = self.init_hidden_noise(image_features)

        # embed the start symbol
        # inputs = self.embed.word_embeddings(["car"] * batch_size).unsqueeze(1).cuda()
        inputs = self.embed.word_embeddings([self.embed.START_SYMBOL] *
                                            batch_size).unsqueeze(1).cuda()
        result.append(self.embed.START_SYMBOL)

        for i in range(self.max_sentence_length - 1):
            inputs = Variable(inputs)
            _, hidden = self.lstm(inputs, hidden)
            outputs = self.output_linear(hidden[0]).squeeze(0)
            predicted = outputs.max(-1)[1]

            # embed the next inputs, unsqueeze is required 'cause of shape (batch_size, 1, embedding_size)
            inputs = self.embed.word_embeddings_from_indices(
                predicted.cpu().data.numpy()).unsqueeze(1).cuda()

            # store the result
            result.append(
                self.embed.word_from_index(predicted.cpu().numpy()[0]))

        if return_sentence:
            result = " ".join(result)  # .split(self.embed.END_SYMBOL)[0]

        return result

    def sample_single_with_embedding(self, image_features):
        batch_size = image_features.size(0)

        # init the result with zeros, and lstm states
        result = torch.zeros(self.max_sentence_length, self.embed.embed_size)
        hidden = self.init_hidden_noise(image_features)

        inputs = self.embed.word_embeddings([self.embed.START_SYMBOL] *
                                            batch_size).unsqueeze(1).cuda()

        for i in range(self.max_sentence_length):
            result[i] = inputs.squeeze(1)
            _, hidden = self.lstm(inputs, hidden)
            outputs = self.output_linear(hidden[0]).squeeze(0)
            predicted = outputs.max(-1)[1]

            # embed the next inputs, unsqueeze is required 'cause of shape (batch_size, 1, embedding_size)
            inputs = self.embed.word_embeddings_from_indices(
                predicted.cpu().data.numpy()).unsqueeze(1).cuda()

        return result

    def sample_with_embedding(self, images_features):
        batch_size = images_features.size(0)

        # init the result with zeros and lstm states
        result = torch.zeros(batch_size, self.max_sentence_length,
                             self.embed.embed_size).cuda()
        hidden = self.init_hidden_noise(images_features)

        # embed the start symbol
        inputs = self.embed.word_embeddings([self.embed.START_SYMBOL] *
                                            batch_size).unsqueeze(1).cuda()

        for i in range(self.max_sentence_length):
            # store the result
            result[:, i] = inputs.squeeze(1)
            inputs = Variable(inputs)
            _, hidden = self.lstm(inputs, hidden)
            outputs = self.output_linear(hidden[0]).squeeze(0)
            predicted = outputs.max(-1)[1]

            # embed the next inputs, unsqueeze is required 'cause of shape (batch_size, 1, embedding_size)
            inputs = self.embed.word_embeddings_from_indices(
                predicted.cpu().data.numpy()).unsqueeze(1).cuda()

        return Variable(result)

    def beam_sample(self, image_features, beam_size=5):
        batch_size = image_features.size(0)
        beam_searcher = BeamSearch(beam_size, batch_size, 17)

        # init the result with zeros and lstm states
        states = self.init_hidden_noise(image_features)
        states = (states[0].repeat(1, beam_size, 1).cuda(),
                  states[1].repeat(1, beam_size, 1).cuda())

        # embed the start symbol
        words_feed = self.embed.word_embeddings([self.embed.START_SYMBOL] * batch_size) \
            .repeat(beam_size, 1).unsqueeze(1).cuda()

        for i in range(self.max_sentence_length):
            hidden, states = self.lstm(words_feed, states)
            outputs = self.output_linear(hidden.squeeze(1))
            beam_indices, words_indices = beam_searcher.expand_beam(
                outputs=outputs)

            if len(beam_indices) == 0 or i == 15:
                generated_captions = beam_searcher.get_results()[:, 0]
                outcaps = self.embed.words_from_indices(
                    generated_captions.cpu().numpy())
            else:
                words_feed = torch.stack([
                    self.embed.word_embeddings_from_indices(words_indices)
                ]).view(beam_size, 1, -1).cuda()
        return " ".join(outcaps)  # .split(self.embed.END_SYMBOL)[0]

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.parameters():
            param.requires_grad = True

    def save(self):
        torch.save({"state_dict": self.state_dict()},
                   FilePathManager.resolve("models/generator00001111.pth"))

    @staticmethod
    def load(corpus: Corpus,
             path: str = "models/generator.pth",
             max_sentence_length=17):
        state_dict = torch.load(FilePathManager.resolve(path))
        state_dict = state_dict["state_dict"]
        generator = ConditionalGenerator(
            corpus, max_sentence_length=max_sentence_length)
        generator.load_state_dict(state_dict)
        return generator
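A hypothetical usage sketch for the generator class above (greedy captioning of a single image). `corpus` and `image_features` (CNN features of shape (1, 4096)) are assumed to come from the surrounding project, and the checkpoint path is the default assumed by `load`.

import torch

generator = ConditionalGenerator.load(corpus).cuda().eval()
with torch.no_grad():
    caption = generator.sample(image_features.cuda())
print(caption)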
if __name__ == "__main__":
    #MODEL = importlib.import_module(FLAGS.model_file) # import network module
    #MODEL_FILE = os.path.join(BASE_DIR, 'models', FLAGS.model_file+'.py')

    ####### log writing
    FLAGS.LOG_DIR = FLAGS.LOG_DIR + '/' + FLAGS.task_name
    #FLAGS.CHECKPOINT_DIR = os.path.join(FLAGS.CHECKPOINT_DIR, FLAGS.task_name)
    #tf_util.mkdir(FLAGS.CHECKPOINT_DIR)

    if not FLAGS.is_training:
        agent = ActiveMVnet(FLAGS)
        senv = ShapeNetEnv(FLAGS)
        if FLAGS.pretrain_restore:
            restore_pretrain(agent)
        else:
            restore_from_iter(agent, FLAGS.test_iter)
        replay_mem = ReplayMemory(FLAGS)
        rollout_obj = Rollout(agent, senv, replay_mem, FLAGS)
        if FLAGS.test_random:
            test_random(agent, FLAGS.test_episode_num, replay_mem,
                        FLAGS.test_iter, rollout_obj)
        elif FLAGS.test_oneway:
            test_oneway(agent, FLAGS.test_episode_num, replay_mem,
                        FLAGS.test_iter, rollout_obj)
        else:
            test_active(agent, FLAGS.test_episode_num, replay_mem,
                        FLAGS.test_iter, rollout_obj)

        sys.exit()
Example no. 18
    def __init__(self,
                 actions,
                 optimizer,
                 convs,
                 fcs,
                 padding,
                 lstm,
                 gamma=0.99,
                 lstm_unit=256,
                 time_horizon=5,
                 policy_factor=1.0,
                 value_factor=0.5,
                 entropy_factor=0.01,
                 grad_clip=40.0,
                 state_shape=[84, 84, 1],
                 buffer_size=2e3,
                 rp_frame=3,
                 phi=lambda s: s,
                 name='global'):
        self.actions = actions
        self.gamma = gamma
        self.name = name
        self.time_horizon = time_horizon
        self.state_shape = state_shape
        self.rp_frame = rp_frame
        self.phi = phi

        self._act,\
        self._train,\
        self._update_local = build_graph.build_train(
            convs=convs,
            fcs=fcs,
            padding=padding,
            lstm=lstm,
            num_actions=len(actions),
            optimizer=optimizer,
            lstm_unit=lstm_unit,
            state_shape=state_shape,
            grad_clip=grad_clip,
            policy_factor=policy_factor,
            value_factor=value_factor,
            entropy_factor=entropy_factor,
            rp_frame=rp_frame,
            scope=name
        )

        # rnn state variables
        self.initial_state = np.zeros((1, lstm_unit), np.float32)
        self.rnn_state0 = self.initial_state
        self.rnn_state1 = self.initial_state

        # last state variables
        self.zero_state = np.zeros(state_shape, dtype=np.float32)
        self.initial_last_obs = [self.zero_state for _ in range(rp_frame)]
        self.last_obs = deque(self.initial_last_obs, maxlen=rp_frame)
        self.last_action = deque([0, 0], maxlen=2)
        self.value_tm1 = None
        self.reward_tm1 = 0.0

        # buffers
        self.rollout = Rollout()
        self.buffer = ReplayBuffer(capacity=buffer_size)

        self.t = 0
        self.t_in_episode = 0
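The `Rollout()` buffer above is flushed every `time_horizon` steps; at that point an A3C-style agent typically computes bootstrapped n-step returns for the stored rewards. A minimal sketch with illustrative names (this is not the repository's build_graph code):

import numpy as np

def n_step_returns(rewards, bootstrap_value, gamma=0.99):
    """Discounted returns for a short rollout, bootstrapped with V(s_T) (sketch)."""
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = bootstrap_value
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns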
Example no. 19
    discriminator = discriminator.cuda()

# Pretrain Discriminator
dis_criterion = nn.NLLLoss(size_average=False)
dis_optimizer = optim.Adam(discriminator.parameters())
if opt.cuda:
    dis_criterion = dis_criterion.cuda()
print('Pretrain Discriminator ...')
for epoch in range(PRE_EPOCH_NUM):
    loss, acc = train_discriminator(discriminator, generators,
                                    real_data_iterator, dis_criterion,
                                    dis_optimizer)
    print('Epoch [{}], loss: {}, accuracy: {}'.format(epoch, loss, acc))

# # Adversarial Training
rollouts = [Rollout(generator, 0.8) for generator in generators]
print('#####################################################')
print('Start Adversarial Training...')
gen_gan_losses = [GANLoss() for _ in generators]
gen_gan_optm = [optim.Adam(generator.parameters()) for generator in generators]
if opt.cuda:
    gen_gan_losses = [gen_gan_loss.cuda() for gen_gan_loss in gen_gan_losses]
# gen_criterion = nn.NLLLoss(size_average=False)
# if opt.cuda:
# gen_criterion = gen_criterion.cuda()
# dis_criterion = nn.NLLLoss(size_average=False)
# dis_optimizer = optim.Adam(discriminator.parameters())
# if opt.cuda:
# dis_criterion = dis_criterion.cuda()
for total_batch in range(TOTAL_BATCH):
    for _ in range(3):
Example no. 20
def main(batch_size, num=None):
    if batch_size is None:
        batch_size = 1
    x, vocabulary, reverse_vocab, sentence_lengths = read_sampleFile(num=num)
    if batch_size > len(x):
        batch_size = len(x)
    start_token = vocabulary['START']
    end_token = vocabulary['END']
    pad_token = vocabulary['PAD']
    ignored_tokens = [start_token, end_token, pad_token]
    vocab_size = len(vocabulary)

    log = openLog()
    log.write("###### start to pretrain generator: {}\n".format(
        datetime.now()))
    log.close()
    generator = pretrain_generator(x,
                                   start_token=start_token,
                                   end_token=end_token,
                                   ignored_tokens=ignored_tokens,
                                   sentence_lengths=torch.tensor(
                                       sentence_lengths, device=DEVICE).long(),
                                   batch_size=batch_size,
                                   vocab_size=vocab_size)
    x_gen = generator.generate(start_token=start_token,
                               ignored_tokens=ignored_tokens,
                               batch_size=len(x))
    log = openLog()
    log.write("###### start to pretrain discriminator: {}\n".format(
        datetime.now()))
    log.close()
    discriminator = train_discriminator_wrapper(x, x_gen, batch_size,
                                                vocab_size)
    rollout = Rollout(generator, r_update_rate=0.8)
    rollout = torch.nn.DataParallel(rollout)  #, device_ids=[0])
    rollout.to(DEVICE)

    log = openLog()
    log.write("###### start to train adversarial net: {}\n".format(
        datetime.now()))
    log.close()
    for total_batch in range(TOTAL_BATCH):
        log = openLog()
        log.write('batch: {} : {}\n'.format(total_batch, datetime.now()))
        print('batch: {} : {}\n'.format(total_batch, datetime.now()))
        log.close()
        for it in range(1):
            samples = generator.generate(start_token=start_token,
                                         ignored_tokens=ignored_tokens,
                                         batch_size=batch_size)
            # Take average of ROLLOUT_ITER times of rewards.
            #   The more times a [0,1] class (positive, real data)
            #   is returned, the higher the reward.
            rewards = getReward(samples, rollout, discriminator)
            (generator, y_prob_all,
             y_output_all) = train_generator(model=generator,
                                             x=samples,
                                             reward=rewards,
                                             iter_n_gen=1,
                                             batch_size=batch_size,
                                             sentence_lengths=sentence_lengths)

        rollout.module.update_params(generator)

        for iter_n_dis in range(DIS_NUM_EPOCH):
            log = openLog()
            log.write('  iter_n_dis: {} : {}\n'.format(iter_n_dis,
                                                       datetime.now()))
            log.close()
            x_gen = generator.generate(start_token=start_token,
                                       ignored_tokens=ignored_tokens,
                                       batch_size=len(x))
            discriminator = train_discriminator_wrapper(
                x, x_gen, batch_size, vocab_size)

    log = openLog()
    log.write('###### training done: {}\n'.format(datetime.now()))
    log.close()

    torch.save(reverse_vocab, PATH + 'reverse_vocab.pkl')
    try:
        torch.save(generator, PATH + 'generator.pkl')
        print('successfully saved generator model.')
    except Exception as e:
        print('error: failed to save generator model: {}'.format(e))

    log = openLog('genTxt.txt')
    num = generator.generate(batch_size=batch_size)
    log.close()
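
A minimal sketch of the reward averaging described in the comment inside the adversarial loop above; the helper names (complete_sequence, rollout_iter) are hypothetical stand-ins, and the discriminator is assumed to return a per-sequence score for the positive ("real") class:

import torch

def monte_carlo_rewards(samples, complete_sequence, discriminator, rollout_iter=4):
    # samples: (batch, seq_len) token ids produced by the generator
    batch_size, seq_len = samples.size()
    rewards = torch.zeros(batch_size, seq_len)
    for t in range(1, seq_len + 1):
        for _ in range(rollout_iter):
            full = complete_sequence(samples[:, :t], seq_len)   # finish each prefix
            p_real = discriminator(full)[:, 1]                  # positive-class score
            rewards[:, t - 1] += p_real / rollout_iter          # average over rollouts
    return rewards

# toy usage with stand-in callables
toy = torch.randint(0, 10, (2, 5))
pad = lambda prefix, L: torch.cat(
    [prefix, torch.zeros(prefix.size(0), L - prefix.size(1), dtype=torch.long)], dim=1)
disc = lambda seq: torch.rand(seq.size(0), 2)
print(monte_carlo_rewards(toy, pad, disc))
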
Esempio n. 21
0
def train(active_mv):

    senv = ShapeNetEnv(FLAGS)
    replay_mem = ReplayMemory(FLAGS)

    #### for debug
    #a = np.array([[1,0,1],[0,0,0]])
    #b = np.array([[1,0,1],[0,1,0]])
    #print('IoU: {}'.format(replay_mem.calu_IoU(a, b)))
    #sys.exit()
    #### for debug

    log_string('====== Starting burning in memories ======')
    burn_in(senv, replay_mem)
    log_string('====== Done. {} trajectories burnt in ======'.format(
        FLAGS.burn_in_length))

    #epsilon = FLAGS.init_eps
    K_single = np.asarray([[420.0, 0.0, 112.0], [0.0, 420.0, 112.0],
                           [0.0, 0.0, 1]])
    K_list = np.tile(K_single[None, None, ...],
                     (1, FLAGS.max_episode_length, 1, 1))

    rollout_obj = Rollout(active_mv, senv, replay_mem, FLAGS)
    ### burn-in (pretrain) for MVnet
    if FLAGS.burn_in_iter > 0:
        for i in range(FLAGS.burnin_start_iter,
                       FLAGS.burnin_start_iter + FLAGS.burn_in_iter):
            rollout_obj.go(i,
                           verbose=True,
                           add_to_mem=True,
                           mode='random',
                           is_train=True)
            if not FLAGS.random_pretrain:
                replay_mem.enable_gbl()
                mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
            else:
                mvnet_input = replay_mem.get_batch_list_random(
                    senv, FLAGS.batch_size)
            tic = time.time()
            out_stuff = active_mv.run_step(mvnet_input,
                                           mode='burnin',
                                           is_training=True)
            burnin_summ = burnin_log(i, out_stuff, time.time() - tic)
            active_mv.train_writer.add_summary(burnin_summ, i)

            if (i + 1) % 5000 == 0 and i > FLAGS.burnin_start_iter:
                save_pretrain(active_mv, i + 1)

            if (i + 1) % 1000 == 0 and i > FLAGS.burnin_start_iter:
                evaluate_burnin(active_mv,
                                FLAGS.test_episode_num,
                                replay_mem,
                                i + 1,
                                rollout_obj,
                                mode='random')

    for i_idx in range(FLAGS.max_iter):

        t0 = time.time()

        rollout_obj.go(i_idx, verbose=True, add_to_mem=True, mode='random')
        t1 = time.time()

        replay_mem.enable_gbl()
        mvnet_input = replay_mem.get_batch_list(FLAGS.batch_size)
        t2 = time.time()

        out_stuff = active_mv.run_step(mvnet_input,
                                       mode='train_mv',
                                       is_training=True)
        replay_mem.disable_gbl()
        t3 = time.time()

        train_log(i_idx, out_stuff, (t0, t1, t2, t3))

        active_mv.train_writer.add_summary(out_stuff.merged_train, i_idx)

        if i_idx % FLAGS.save_every_step == 0 and i_idx > 0:
            save(active_mv, i_idx, i_idx, i_idx)

        if i_idx % FLAGS.test_every_step == 0 and i_idx > 0:
            #print('Evaluating active policy')
            #evaluate(active_mv, FLAGS.test_episode_num, replay_mem, i_idx, rollout_obj, mode='active')
            print('Evaluating random policy')
            evaluate(active_mv,
                     FLAGS.test_episode_num,
                     replay_mem,
                     i_idx,
                     rollout_obj,
                     mode='random')
Esempio n. 22
0
def main(opt):

    cuda = opt.cuda
    visualize = opt.visualize
    print(f"cuda = {cuda}, visualize = {opt.visualize}")
    if visualize:
        if PRE_EPOCH_GEN > 0:
            pretrain_G_score_logger = VisdomPlotLogger(
                'line', opts={'title': 'Pre-train G Goodness Score'})
        if PRE_EPOCH_DIS > 0:
            pretrain_D_loss_logger = VisdomPlotLogger(
                'line', opts={'title': 'Pre-train D Loss'})
        adversarial_G_score_logger = VisdomPlotLogger(
            'line',
            opts={
                'title': f'Adversarial G {GD} Goodness Score',
                'Y': '{0, 13}',
                'X': '{0, TOTAL_BATCH}'
            })
        if CHECK_VARIANCE:
            G_variance_logger = VisdomPlotLogger(
                'line', opts={'title': f'Adversarial G {GD} Variance'})
        G_text_logger = VisdomTextLogger(update_type='APPEND')
        adversarial_D_loss_logger = VisdomPlotLogger(
            'line', opts={'title': 'Adversarial Batch D Loss'})

    # Define Networks
    generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, cuda)
    n_gen = Variable(torch.Tensor([get_n_params(generator)]))
    use_cuda = False
    if cuda:
        n_gen = n_gen.cuda()
        use_cuda = True
    print('Number of parameters in the generator: {}'.format(n_gen))
    discriminator = LSTMDiscriminator(d_num_class, VOCAB_SIZE,
                                      d_lstm_hidden_dim, use_cuda)
    c_phi_hat = AnnexNetwork(d_num_class, VOCAB_SIZE, d_emb_dim,
                             c_filter_sizes, c_num_filters, d_dropout,
                             BATCH_SIZE, g_sequence_len)
    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        c_phi_hat = c_phi_hat.cuda()

    # Generate toy data using target lstm
    print('Generating data ...')

    # Load data from file
    gen_data_iter = DataLoader(POSITIVE_FILE, BATCH_SIZE)

    gen_criterion = nn.NLLLoss(size_average=False)
    gen_optimizer = optim.Adam(generator.parameters())
    if cuda:
        gen_criterion = gen_criterion.cuda()
    # Pretrain Generator using MLE
    pre_train_scores = []
    if MLE:
        print('Pretrain with MLE ...')
        for epoch in range(int(np.ceil(PRE_EPOCH_GEN))):
            loss = train_epoch(generator, gen_data_iter, gen_criterion,
                               gen_optimizer, PRE_EPOCH_GEN, epoch, cuda)
            print('Epoch [%d] Model Loss: %f' % (epoch, loss))
            samples = generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                                       EVAL_FILE)
            eval_iter = DataLoader(EVAL_FILE, BATCH_SIZE)
            generated_string = eval_iter.convert_to_char(samples)
            print(generated_string)
            eval_score = get_data_goodness_score(generated_string, SPACES)
            if SPACES == False:
                kl_score = get_data_freq(generated_string)
            else:
                kl_score = -1
            freq_score = get_char_freq(generated_string, SPACES)
            pre_train_scores.append(eval_score)
            print('Epoch [%d] Generation Score: %f' % (epoch, eval_score))
            print('Epoch [%d] KL Score: %f' % (epoch, kl_score))
            print('Epoch [{}] Character distribution: {}'.format(
                epoch, list(freq_score)))

            torch.save(
                generator.state_dict(),
                f"checkpoints/MLE_space_{SPACES}_length_{SEQ_LEN}_preTrainG_epoch_{epoch}.pth"
            )

            if visualize:
                pretrain_G_score_logger.log(epoch, eval_score)
    else:
        generator.load_state_dict(torch.load(weights_path))

    # Finishing training with MLE
    if GD == "MLE":
        for epoch in range(3 * int(GENERATED_NUM / BATCH_SIZE)):
            loss = train_epoch_batch(generator, gen_data_iter, gen_criterion,
                                     gen_optimizer, PRE_EPOCH_GEN, epoch,
                                     int(GENERATED_NUM / BATCH_SIZE), cuda)
            print('Epoch [%d] Model Loss: %f' % (epoch, loss))
            samples = generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                                       EVAL_FILE)
            eval_iter = DataLoader(EVAL_FILE, BATCH_SIZE)
            generated_string = eval_iter.convert_to_char(samples)
            print(generated_string)
            eval_score = get_data_goodness_score(generated_string, SPACES)
            if SPACES == False:
                kl_score = get_data_freq(generated_string)
            else:
                kl_score = -1
            freq_score = get_char_freq(generated_string, SPACES)
            pre_train_scores.append(eval_score)
            print('Epoch [%d] Generation Score: %f' % (epoch, eval_score))
            print('Epoch [%d] KL Score: %f' % (epoch, kl_score))
            print('Epoch [{}] Character distribution: {}'.format(
                epoch, list(freq_score)))

            torch.save(
                generator.state_dict(),
                f"checkpoints/MLE_space_{SPACES}_length_{SEQ_LEN}_preTrainG_epoch_{epoch}.pth"
            )

            if visualize:
                pretrain_G_score_logger.log(epoch, eval_score)
    # Pretrain Discriminator
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    print('Pretrain Discriminator ...')
    for epoch in range(PRE_EPOCH_DIS):
        samples = generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                                   NEGATIVE_FILE)
        dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE,
                                    SEQ_LEN)
        for _ in range(PRE_ITER_DIS):
            loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                               dis_optimizer, 1, 1, cuda)
            print('Epoch [%d], loss: %f' % (epoch, loss))
            if visualize:
                pretrain_D_loss_logger.log(epoch, loss)
    # Adversarial Training
    rollout = Rollout(generator, UPDATE_RATE)
    print('#####################################################')
    print('Start Adversarial Training...\n')

    gen_gan_loss = GANLoss()
    gen_gan_optm = optim.Adam(generator.parameters())
    if cuda:
        gen_gan_loss = gen_gan_loss.cuda()
    gen_criterion = nn.NLLLoss(size_average=False)
    if cuda:
        gen_criterion = gen_criterion.cuda()

    dis_criterion = nn.NLLLoss(size_average=False)
    dis_criterion_bce = nn.BCELoss()
    dis_optimizer = optim.Adam(discriminator.parameters())
    if cuda:
        dis_criterion = dis_criterion.cuda()

    c_phi_hat_loss = VarianceLoss()
    if cuda:
        c_phi_hat_loss = c_phi_hat_loss.cuda()
    c_phi_hat_optm = optim.Adam(c_phi_hat.parameters())

    gen_scores = pre_train_scores

    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(G_STEPS):
            samples = generator.sample(BATCH_SIZE, g_sequence_len)
            # samples has size (BS, sequence_len)
            # Construct the input to the generator, add zeros before samples and delete the last column
            zeros = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor)
            if samples.is_cuda:
                zeros = zeros.cuda()
            inputs = Variable(
                torch.cat([zeros, samples.data], dim=1)[:, :-1].contiguous())
            targets = Variable(samples.data).contiguous().view((-1, ))
            if opt.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()
            # Calculate the reward
            rewards = rollout.get_reward(samples, discriminator, VOCAB_SIZE,
                                         cuda)
            rewards = Variable(torch.Tensor(rewards))
            if cuda:
                rewards = rewards.cuda()
            rewards = torch.exp(rewards).contiguous().view((-1, ))
            # rewards has size (BS)
            prob = generator.forward(inputs)
            # prob has size (BS*sequence_len, VOCAB_SIZE)
            # 3.a
            theta_prime = g_output_prob(prob)
            # theta_prime has size (BS*sequence_len, VOCAB_SIZE)
            # 3.e and f
            c_phi_z_ori, c_phi_z_tilde_ori = c_phi_out(
                GD,
                c_phi_hat,
                theta_prime,
                discriminator,
                temperature=DEFAULT_TEMPERATURE,
                eta=DEFAULT_ETA,
                cuda=cuda)
            c_phi_z_ori = torch.exp(c_phi_z_ori)
            c_phi_z_tilde_ori = torch.exp(c_phi_z_tilde_ori)
            c_phi_z = torch.sum(c_phi_z_ori[:, 1]) / BATCH_SIZE
            c_phi_z_tilde = -torch.sum(c_phi_z_tilde_ori[:, 1]) / BATCH_SIZE
            if opt.cuda:
                c_phi_z = c_phi_z.cuda()
                c_phi_z_tilde = c_phi_z_tilde.cuda()
                c_phi_hat = c_phi_hat.cuda()
            # 3.i
            grads = []
            first_term_grads = []
            # 3.h optimization step
            # first, empty the gradient buffers
            gen_gan_optm.zero_grad()
            # rearrange prob to (BATCH_SIZE, seq_len, VOCAB_SIZE) for per-timestep access
            new_prob = prob.view((BATCH_SIZE, g_sequence_len, VOCAB_SIZE))
            # 3.g new gradient loss for relax
            batch_i_grads_1 = gen_gan_loss.forward_reward_grads(
                samples, new_prob, rewards, generator, BATCH_SIZE,
                g_sequence_len, VOCAB_SIZE, cuda)
            batch_i_grads_2 = gen_gan_loss.forward_reward_grads(
                samples, new_prob, c_phi_z_tilde_ori[:, 1], generator,
                BATCH_SIZE, g_sequence_len, VOCAB_SIZE, cuda)
            # batch_i_grads_1 and batch_i_grads_2 should be of length BATCH SIZE of arrays of all the gradients
            # # 3.i
            batch_grads = batch_i_grads_1
            if GD != "REINFORCE":
                for i in range(len(batch_i_grads_1)):
                    for j in range(len(batch_i_grads_1[i])):
                        batch_grads[i][j] = torch.add(batch_grads[i][j], (-1) *
                                                      batch_i_grads_2[i][j])
            # batch_grads should be of length BATCH SIZE
            grads.append(batch_grads)
            # NOW, TRAIN THE GENERATOR
            generator.zero_grad()
            for i in range(g_sequence_len):
                # 3.g new gradient loss for relax
                cond_prob = gen_gan_loss.forward_reward(
                    i, samples, new_prob, rewards, BATCH_SIZE, g_sequence_len,
                    VOCAB_SIZE, cuda)
                c_term = gen_gan_loss.forward_reward(i, samples, new_prob,
                                                     c_phi_z_tilde_ori[:, 1],
                                                     BATCH_SIZE,
                                                     g_sequence_len,
                                                     VOCAB_SIZE, cuda)
                if GD != "REINFORCE":
                    cond_prob = torch.add(cond_prob, (-1) * c_term)
                new_prob[:, i, :].backward(cond_prob, retain_graph=True)
            # 3.h - still training the generator, with the last two terms of the RELAX equation
            if GD != "REINFORCE":
                c_phi_z.backward(retain_graph=True)
                c_phi_z_tilde.backward(retain_graph=True)
            gen_gan_optm.step()
            # 3.i
            if CHECK_VARIANCE:
                # c_phi_z term
                partial_grads = []
                for j in range(BATCH_SIZE):
                    generator.zero_grad()
                    c_phi_z_ori[j, 1].backward(retain_graph=True)
                    j_grads = []
                    for p in generator.parameters():
                        j_grads.append(p.grad.clone())
                    partial_grads.append(j_grads)
                grads.append(partial_grads)
                # c_phi_z_tilde term
                partial_grads = []
                for j in range(BATCH_SIZE):
                    generator.zero_grad()
                    c_phi_z_tilde_ori[j, 1].backward(retain_graph=True)
                    j_grads = []
                    for p in generator.parameters():
                        j_grads.append(-1 * p.grad.clone())
                    partial_grads.append(j_grads)
                grads.append(partial_grads)
                # Uncomment the below code if you want to check gradients
                """
                print('1st contribution to the gradient')
                print(grads[0][0][6])
                print('2nd contribution to the gradient')
                print(grads[1][0][6])
                print('3rd contribution to the gradient')
                print(grads[2][0][6])
                """
                #grads should be of length 3
                #grads[0] should be of length BATCH SIZE
                # 3.j
                all_grads = grads[0]
                if GD != "REINFORCE":
                    for i in range(len(grads[0])):
                        for j in range(len(grads[0][i])):
                            all_grads[i][j] = torch.add(
                                torch.add(all_grads[i][j], grads[1][i][j]),
                                grads[2][i][j])
                # all_grads should be of length BATCH_SIZE
                c_phi_hat_optm.zero_grad()
                var_loss = c_phi_hat_loss.forward(all_grads, cuda)  #/n_gen
                true_variance = c_phi_hat_loss.forward_variance(
                    all_grads, cuda)
                var_loss.backward()
                c_phi_hat_optm.step()
                print(
                    'Batch [{}] Estimate of the variance of the gradient at step {}: {}'
                    .format(total_batch, it, true_variance[0]))
                if visualize:
                    G_variance_logger.log((total_batch + it), true_variance[0])

        # Evaluate the quality of the Generator outputs
        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            samples = generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                                       EVAL_FILE)
            eval_iter = DataLoader(EVAL_FILE, BATCH_SIZE)
            generated_string = eval_iter.convert_to_char(samples)
            print(generated_string)
            eval_score = get_data_goodness_score(generated_string, SPACES)
            if SPACES == False:
                kl_score = get_data_freq(generated_string)
            else:
                kl_score = -1
            freq_score = get_char_freq(generated_string, SPACES)
            gen_scores.append(eval_score)
            print('Batch [%d] Generation Score: %f' %
                  (total_batch, eval_score))
            print('Batch [%d] KL Score: %f' % (total_batch, kl_score))
            print('Epoch [{}] Character distribution: {}'.format(
                total_batch, list(freq_score)))

            #Checkpoint & Visualize
            if total_batch % 10 == 0 or total_batch == TOTAL_BATCH - 1:
                torch.save(
                    generator.state_dict(),
                    f'checkpoints/{GD}_G_space_{SPACES}_pretrain_{PRE_EPOCH_GEN}_batch_{total_batch}.pth'
                )
            if visualize:
                [G_text_logger.log(line) for line in generated_string]
                adversarial_G_score_logger.log(total_batch, eval_score)

        # Train the discriminator
        batch_G_loss = 0.0

        for b in range(D_EPOCHS):

            for data, _ in gen_data_iter:

                data = Variable(data)
                real_data = convert_to_one_hot(data, VOCAB_SIZE, cuda)
                real_target = Variable(torch.ones((data.size(0), 1)))
                samples = generator.sample(data.size(0),
                                           g_sequence_len)  # bs x seq_len
                fake_data = convert_to_one_hot(
                    samples, VOCAB_SIZE, cuda)  # bs x seq_len x vocab_size
                fake_target = Variable(torch.zeros((data.size(0), 1)))

                if cuda:
                    real_target = real_target.cuda()
                    fake_target = fake_target.cuda()
                    real_data = real_data.cuda()
                    fake_data = fake_data.cuda()

                real_pred = torch.exp(discriminator(real_data)[:, 1])
                fake_pred = torch.exp(discriminator(fake_data)[:, 1])

                D_real_loss = dis_criterion_bce(real_pred, real_target)
                D_fake_loss = dis_criterion_bce(fake_pred, fake_target)
                D_loss = D_real_loss + D_fake_loss
                dis_optimizer.zero_grad()
                D_loss.backward()
                dis_optimizer.step()

            gen_data_iter.reset()

            print('Batch [{}] Discriminator Loss at epoch {}: {}'.format(
                total_batch, b, D_loss.item()))

        if visualize:
            adversarial_D_loss_logger.log(total_batch, D_loss.item())

    if not visualize:
        plt.plot(gen_scores)
        plt.ylim((0, 13))
        plt.title('{}_after_{}_epochs_of_pretraining'.format(
            GD, PRE_EPOCH_GEN))
        plt.show()
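
For the CHECK_VARIANCE branch above, the per-sample gradient lists are reduced to a scalar variance estimate; a minimal sketch of one such reduction (an assumption about what forward_variance resembles: an elementwise variance across the batch, averaged over all parameters):

import torch

def gradient_variance(per_sample_grads):
    # per_sample_grads: list over samples, each a list of per-parameter gradient tensors
    flat = [torch.cat([g.reshape(-1) for g in sample]) for sample in per_sample_grads]
    stacked = torch.stack(flat)                       # (batch, total_params)
    return stacked.var(dim=0, unbiased=False).mean()

# toy usage: two samples, two parameter tensors each
grads = [[torch.tensor([1.0, 2.0]), torch.tensor([[0.5]])],
         [torch.tensor([3.0, 0.0]), torch.tensor([[1.5]])]]
print(gradient_variance(grads))  # tensor(0.7500)
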
Esempio n. 23
0
class MultitaskPolicy(object):
	def __init__(
			self,
			env,
			env_config,
			policy,
			value,
			oracle_network,
			share_exp,
			oracle):

		self.env_config = env_config
		self.PGNetwork = policy
		self.VNetwork = value
		self.ZNetwork = oracle_network
		self.rollout = Rollout(number_episode=args.rollouts, map_index=env_config.map_index)

		self.oracle = oracle
		self.share_exp = share_exp
		self.env = env

	@staticmethod
	def _get_state_sharing_action(state_index, current_oracle, task, other_task, share_dict):
		share_action = np.random.choice(range(len(current_oracle[task, other_task, state_index])),
										p=np.array(current_oracle[task, other_task, state_index]) / sum(
											current_oracle[task, other_task, state_index]))
		if args.oracle_sharing:
			share_action = share_dict[state_index][task][other_task]
		return share_action

	@staticmethod
	def _clip_important_weight(weight):
		if weight > 1.2:
			weight = 1.2
		if weight < 0.8:
			weight = 0.8
		return weight

	def _process_PV_batch(self, states, actions, drewards, GAEs, next_states, current_policy, current_value, current_oracle, count_dict):
		batch_ss, batch_as, batch_Qs, batch_Ts = [], [], [], []
		share_ss, share_as, share_Ts = [], [], []
		for task in range(self.env.num_task):
			batch_ss.append([])
			batch_as.append([])
			batch_Qs.append([])
			batch_Ts.append([])
			share_ss.append([])
			share_as.append([])
			share_Ts.append([])

		share_dict = {}
		mean_sa_dict = {}

		for current_task in range(self.env.num_task):
			for i, (state, action, actual_value, gae, next_state) in enumerate(zip(states[current_task], actions[current_task], drewards[current_task], GAEs[current_task], next_states[current_task])):
				state_index = state[0]+(state[1]-1)*self.env.bounds_x[1]-1

				advantage = actual_value - current_value[current_task, state_index]
				if args.use_gae:
					advantage = gae	
					
				if self.share_exp:
					if share_dict.get(state_index, -1) == -1:
						share_dict[state_index] = get_state_sharing_info(self.env, state, state_index, args.oracle_sharing, current_oracle)
					if mean_sa_dict.get((state_index, action), -1) == -1:
						mean_sa_dict[state_index, action] = calculate_mean_pi_sa_task(self.env, state, action, state_index, share_dict, count_dict, current_policy, current_oracle, args.oracle_sharing)

					for other_task in range(self.env.num_task):

						share_action = MultitaskPolicy._get_state_sharing_action(state_index, current_oracle, current_task, other_task, share_dict)
						if share_action == 1:
							important_weight = current_policy[other_task, state_index][action]/mean_sa_dict[state_index, action][other_task]
							clip_important_weight = MultitaskPolicy._clip_important_weight(important_weight)

							if (0.8 <= important_weight <= 1.2) or (clip_important_weight*advantage > important_weight * advantage):
								if other_task == current_task:
									# use current_task's experience to update itself
									batch_ss[current_task].append(self.env.cv_state_onehot[state_index].tolist())
									batch_as[current_task].append(self.env.cv_action_onehot[action].tolist())
									batch_Qs[current_task].append(actual_value)
									batch_Ts[current_task].append(important_weight*advantage)
								else:
									# use current_task's experience to update other_task
									share_ss[other_task].append(self.env.cv_state_onehot[state_index].tolist())
									share_as[other_task].append(self.env.cv_action_onehot[action].tolist())
									share_Ts[other_task].append(important_weight * advantage)
				else:
					batch_ss[current_task].append(self.env.cv_state_onehot[state_index].tolist())
					batch_as[current_task].append(self.env.cv_action_onehot[action].tolist())
					batch_Qs[current_task].append(actual_value)
					batch_Ts[current_task].append(advantage)

				#############################################################################################################

		return batch_ss, batch_as, batch_Qs, batch_Ts, share_ss, share_as, share_Ts

	def _process_Z_batch(self, state_dict, count_dict):
		z_ss, z_as, z_rs = {}, {}, {}
		for i in range(self.env.num_task - 1):
			for j in range(i + 1, self.env.num_task):
				z_ss[i, j] = []
				z_as[i, j] = []
				z_rs[i, j] = []
		for v in state_dict.keys():
			for i in range(self.env.num_task - 1):
				for j in range(i + 1, self.env.num_task):
					for action in range(self.env.action_size):

						z_reward = 0.0
						if state_dict[v][i][action] * state_dict[v][j][action] > 0:
							z_reward = min(abs(state_dict[v][i][action]), abs(state_dict[v][j][action]))
							z_action = [0, 1]

						if state_dict[v][i][action] * state_dict[v][j][action] < 0:
							z_reward = min(abs(state_dict[v][i][action]), abs(state_dict[v][j][action]))
							z_action = [1, 0]

						if sum(count_dict[v][i]) == 0 and sum(count_dict[v][j]) > 0:
							z_reward = 0.001
							z_action = [1, 0]

						if sum(count_dict[v][j]) == 0 and sum(count_dict[v][i]) > 0:
							z_reward = 0.001
							z_action = [1, 0]

						if z_reward > 0.0:
							z_ss[i, j].append(self.env.cv_state_onehot[v].tolist())
							z_as[i, j].append(z_action)
							z_rs[i, j].append(z_reward)
		return z_ss, z_as, z_rs

	def _make_batch(self, epoch):
		current_policy, current_value, current_oracle = get_current_policy(self.env, self.PGNetwork, self.VNetwork, self.ZNetwork)

		# states = [
		#task1		[[---episode_1---],...,[---episode_n---]],
		#task2		[[---episode_1---],...,[---episode_n---]]
		#			]
		states, tasks, actions, rewards, next_states = self.rollout.rollout_batch(self.PGNetwork, current_policy, epoch)

		discounted_rewards, GAEs = [], []
		for task in range(self.env.num_task):
			discounted_rewards.append([])
			GAEs.append([])
			for ep_state, ep_next, ep_reward in zip(states[task], next_states[task], rewards[task]):	
				discounted_rewards[task] += discount_rewards(self.env, ep_reward, ep_state, ep_next, task, current_value)
				GAEs[task] += GAE(self.env, ep_reward, ep_state, ep_next, task, current_value)
			
			states[task] = np.concatenate(states[task])       
			tasks[task] = np.concatenate(tasks[task])     
			actions[task] = np.concatenate(actions[task])     
			rewards[task] = np.concatenate(rewards[task])
			next_states[task] = np.concatenate(next_states[task])

		state_dict, count_dict = statistic(self.env, states, actions, discounted_rewards, GAEs, next_states, current_value)
		task_states, task_actions, task_target_values, task_advantages, \
		sharing_states, sharing_actions, sharing_advantages = self._process_PV_batch(states,
																			  actions,
																			  discounted_rewards,
																			  GAEs,
																			  next_states,
																			  current_policy,
																			  current_value,
																			  current_oracle,
																			  count_dict)

		z_states, z_actions, z_rewards = self._process_Z_batch(state_dict, count_dict)

		return task_states, task_actions, task_target_values, task_advantages, \
			   sharing_states, sharing_actions, sharing_advantages, \
			   np.concatenate(rewards), \
			   z_states, z_actions, z_rewards

	def update_value_function(self, states, target_value):
		for t in range(self.env.num_task):
			self.VNetwork[t].update_parameters(Tensor(states[t]), Tensor(target_value[t]))

	def update_sharing_Z_agent(self, states, actions, rewards):
		for i in range(self.env.num_task - 1):
			for j in range(i + 1, self.env.num_task):
				if len(states[i, j]) > 0:
					self.ZNetwork[i, j].update_parameters(Tensor(states[i, j]), Tensor(rewards[i, j]), Tensor(actions[i, j]))

	def update_policy(self, states, actions, advantage):
		for t in range(self.env.num_task):
			self.PGNetwork[t].update_parameters(Tensor(states[t]), Tensor(advantage[t]), Tensor(actions[t]))

	def train(self):
		average_samples = 0
		for epoch in range(args.epochs):
			print(epoch)
			task_states, task_actions, task_target_values, task_advantages, \
			sharing_states, sharing_actions, sharing_advantages, \
			recorded_rewards, \
			z_states, z_actions, z_rewards = self._make_batch(epoch)

			self.update_value_function(task_states,  task_target_values)

			if self.share_exp:
				self.update_sharing_Z_agent(z_states, z_actions, z_rewards)

				for task_index in range(self.env.num_task):
					task_states[task_index] += sharing_states[task_index]
					task_actions[task_index] += sharing_actions[task_index]
					task_advantages[task_index] += sharing_advantages[task_index]

			self.update_policy(task_states, task_actions, task_advantages)

			# WRITE TF SUMMARIES
			average_samples += len(recorded_rewards) / self.env.num_task
			total_reward_of_that_batch = np.sum(recorded_rewards) / self.env.num_task
			mean_reward_of_that_batch = np.divide(total_reward_of_that_batch, args.rollouts)
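
A minimal, self-contained sketch of the acceptance rule implemented by _clip_important_weight and the sharing condition in _process_PV_batch above: a shared sample is kept when its importance weight is close to 1, or when clipping the weight would not shrink the weighted advantage.

def accept_shared_sample(weight, advantage, low=0.8, high=1.2):
    clipped = min(max(weight, low), high)
    if low <= weight <= high or clipped * advantage > weight * advantage:
        return weight * advantage   # the unclipped weighted advantage is what gets batched
    return None                     # the sample is dropped for this task

print(accept_shared_sample(1.5, -2.0))  # kept: clipping would raise the (negative) term
print(accept_shared_sample(1.5, 2.0))   # dropped
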
Esempio n. 24
0
                             pretrain_epochs=PRETRAIN_EPOCHS,
                             songs=songs,
                             char2idx=char2idx,
                             idx2char=idx2char,
                             tb_writer=train_summary_writer,
                             learning_rate=1e-4)
    print('Start pre-training discriminator...')
    disc_pre_trainer.pretrain(save_disc_weights)




rollout = Rollout(  generator=gen,
                    discriminator=disc,
                    batch_size=batch_size,
                    embedding_size=embedding_dim,
                    sequence_length=seq_len,
                    start_token=start_token,
                    rollout_num=rollout_num)


for epoch in range(EPOCHS):
    fake_samples = gen.generate()
    rewards = rollout.get_reward(samples=fake_samples)
    gen_loss = gen.train_step(fake_samples, rewards)
    real_samples, _ = get_batch(seq_len, batch_size)
    disc_loss = 0
    for i in range(disc_steps):
        disc_loss += disc.train_step(fake_samples, real_samples)/disc_steps

    with train_summary_writer.as_default():
Esempio n. 25
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    # Define Networks
    generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
    discriminator = Discriminator(d_num_class, VOCAB_SIZE, d_emb_dim,
                                  d_filter_sizes, d_num_filters, d_dropout)
    target_lstm = TargetLSTM(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
    if opt.cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        target_lstm = target_lstm.cuda()
    # Generate toy data using target lstm
    print('Generating data ...')
    generate_samples(target_lstm, BATCH_SIZE, GENERATED_NUM, POSITIVE_FILE)

    # Load data from file
    gen_data_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)

    # Pretrain Generator using MLE
    gen_criterion = nn.NLLLoss(size_average=False)
    gen_optimizer = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    print('Pretrain with MLE ...')
    for epoch in range(PRE_EPOCH_NUM):
        loss = train_epoch(generator, gen_data_iter, gen_criterion,
                           gen_optimizer)
        print('Epoch [%d] Model Loss: %f' % (epoch, loss))
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
        loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
        print('Epoch [%d] True Loss: %f' % (epoch, loss))

    # Pretrain Discriminator
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    print('Pretrain Discriminator ...')
    for epoch in range(5):
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE)
        dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE)
        for _ in range(3):
            loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                               dis_optimizer)
            print('Epoch [%d], loss: %f' % (epoch, loss))
    # Adversarial Training
    rollout = Rollout(generator, 0.8)
    print('#####################################################')
    print('Start Adversarial Training...\n')
    gen_gan_loss = GANLoss()
    gen_gan_optm = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_gan_loss = gen_gan_loss.cuda()
    gen_criterion = nn.NLLLoss(size_average=False)
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.sample(BATCH_SIZE, g_sequence_len)
            # construct the input to the generator, add zeros before samples and delete the last column
            zeros = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor)
            if samples.is_cuda:
                zeros = zeros.cuda()
            inputs = Variable(
                torch.cat([zeros, samples.data], dim=1)[:, :-1].contiguous())
            targets = Variable(samples.data).contiguous().view((-1, ))
            # calculate the reward
            rewards = rollout.get_reward(samples, 16, discriminator)
            rewards = Variable(torch.Tensor(rewards))
            if opt.cuda:
                rewards = torch.exp(rewards.cuda()).contiguous().view((-1, ))
            prob = generator.forward(inputs)
            loss = gen_gan_loss(prob, targets, rewards)
            gen_gan_optm.zero_grad()
            loss.backward()
            gen_gan_optm.step()

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
            eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
            loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
            print('Batch [%d] True Loss: %f' % (total_batch, loss))
        rollout.update_params()

        for _ in range(4):
            generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                             NEGATIVE_FILE)
            dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                        BATCH_SIZE)
            for _ in range(2):
                loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                                   dis_optimizer)
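
gen_gan_loss(prob, targets, rewards) above is a reward-weighted policy-gradient loss; a minimal sketch of a typical implementation, assuming prob holds log-probabilities (consistent with the NLLLoss-based pretraining):

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    # Reward-weighted negative log-likelihood over the sampled tokens
    # (a sketch, not necessarily this repo's exact code).
    def forward(self, prob, target, reward):
        # prob: (batch*seq_len, vocab) log-probs; target, reward: (batch*seq_len,)
        one_hot = torch.zeros_like(prob).scatter_(1, target.view(-1, 1), 1.0)
        log_p = (prob * one_hot).sum(dim=1)
        return -(log_p * reward).sum()

# toy usage
log_probs = torch.randn(6, 4).log_softmax(dim=1)
tokens = torch.randint(0, 4, (6,))
rewards = torch.rand(6)
print(GANLoss()(log_probs, tokens, rewards))
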
Esempio n. 26
0
class PPOAgent():
    def __init__(self,
                 state_size,
                 action_size,
                 lr=1e-3,
                 gamma=0.99,
                 clipping_epsilon=0.1,
                 ppo_epochs=10,
                 minibatch_size=64,
                 rollout_length=1000,
                 gae_lambda=0.95):
        self.lr = lr
        self.clipping_epsilon = clipping_epsilon
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size
        self.rollout_length = rollout_length

        self.policy = PolicyNet(state_size, action_size)
        self.value_estimator = ValueNet(state_size)
        self.rollout = Rollout(gamma=gamma, gae_lambda=gae_lambda)

    def start_episode(self):
        self.episode_rewards = []
        self.rollout.start_rollout()

    def act(self, state):

        # Check if the rollout is full and needs processing
        if len(self.rollout) == self.rollout_length:
            self.learn()
            self.rollout.start_rollout()

        # Derive action distribution from policy network
        means, sigmas = self.policy(state)
        action_distribution = torch.distributions.Normal(means, sigmas)
        action = action_distribution.sample()
        action_log_prob = action_distribution.log_prob(action)

        # Derive state value estimate from value network
        state_value = self.value_estimator(state).squeeze()

        # Record decision and return sampled action
        self.rollout.record_decision(state, state_value, action,
                                     action_log_prob)
        return action

    def finish_episode(self):
        self.learn()

    def record_outcome(self, reward):
        self.episode_rewards.append(reward)
        self.rollout.record_outcome(reward)

    def average_episode_return(self):
        return sum([r.mean().item() for r in self.episode_rewards])

    def get_current_policy_probs(self, states, actions):

        # For the given state/action pairs, create a distribution from the policy and get the log probs
        means, sigmas = self.policy(states)
        action_distribution = torch.distributions.Normal(means, sigmas)
        current_policy_log_probs = action_distribution.log_prob(actions)

        # Sum log probs over the possible actions
        current_policy_log_probs = current_policy_log_probs.sum(-1)

        return torch.exp(current_policy_log_probs)

    def learn(self):

        (states, actions, future_returns, normalised_advantages, original_policy_probs) = \
            self.rollout.flatten_trajectories()

        # Run through PPO epochs
        policy_optimiser = optim.Adam(self.policy.parameters(),
                                      lr=self.lr,
                                      eps=1e-5)
        value_estimator_optimiser = optim.Adam(
            self.value_estimator.parameters(), lr=self.lr, eps=1e-5)
        for ppo_epoch in range(self.ppo_epochs):

            # Sample the trajectories randomly in mini-batches
            for indices in random_sample(np.arange(states.shape[0]),
                                         self.minibatch_size):

                # Sample using sample indices
                states_sample = states[indices]
                actions_sample = actions[indices]
                future_returns_sample = future_returns[indices]
                normalised_advantages_sample = normalised_advantages[indices]
                original_policy_probs_sample = original_policy_probs[indices]

                # Use the current policy to get the probabilities for the sample states and actions
                # We use these to weight the likelihoods, allowing reuse of the rollout
                current_policy_probs_sample = self.get_current_policy_probs(
                    states_sample, actions_sample)

                # Define PPO surrogate and clip to get the policy loss
                sampling_ratio = current_policy_probs_sample / original_policy_probs_sample
                clipped_ratio = torch.clamp(sampling_ratio,
                                            1 - self.clipping_epsilon,
                                            1 + self.clipping_epsilon)
                clipped_surrogate = torch.min(
                    sampling_ratio * normalised_advantages_sample,
                    clipped_ratio * normalised_advantages_sample)
                policy_loss = -torch.mean(clipped_surrogate)

                # Define value estimator loss
                state_values_sample = self.value_estimator(
                    states_sample).squeeze()
                value_estimator_loss = nn.MSELoss()(state_values_sample,
                                                    future_returns_sample)

                # Update value estimator
                value_estimator_optimiser.zero_grad()
                value_estimator_loss.backward()
                nn.utils.clip_grad_norm_(self.value_estimator.parameters(),
                                         0.75)
                value_estimator_optimiser.step()

                # Update policy
                policy_optimiser.zero_grad()
                policy_loss.backward()
                nn.utils.clip_grad_norm_(self.policy.parameters(), 0.75)
                policy_optimiser.step()
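
The policy loss computed in learn() above is the standard PPO clipped surrogate; a minimal numeric illustration with dummy tensors:

import torch

eps = 0.1
ratio = torch.tensor([0.7, 1.0, 1.3])        # current / original policy probabilities
advantage = torch.tensor([1.0, -0.5, 2.0])   # normalised advantages
clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
policy_loss = -torch.min(ratio * advantage, clipped * advantage).mean()
print(policy_loss)  # tensor(-0.8000)
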
Esempio n. 27
0
    dis_optimizer = Optim('adam', 1e-3, lr_decay=0.5, max_weight_value=1.0)
    dis_optimizer.set_parameters(dis.parameters())
    # train_discriminator(dis, dis_optimizer, train_iter, gen, pretrain_acc, PRETRAIN_DISC_EPOCHS)

    # torch.save(dis.state_dict(), pretrained_dis_path)
    # dis.load_state_dict(torch.load(pretrained_dis_path))
    # ADVERSARIAL TRAINING
    pg_count = 10000
    best_advbleu = 0

    pg_optimizer = Optim('myadam', 1e-3, max_grad_norm=5)
    pg_optimizer.set_parameters(gen.parameters())
    gen_optimizer.reset_learningrate(1e-3)
    dis_optimizer.reset_learningrate(1e-3)

    rollout = Rollout(gen, update_learning_rate)
    for epoch in range(ADV_TRAIN_EPOCHS):

        for i, data in enumerate(train_iter):
            tgt_data = data.target[0].permute(1, 0)  # batch_size X length
            src_data_wrap = data.source
            ans = data.answer[0]

            if CUDA:
                scr_data = data.source[0].to(device)
                scr_lengths = data.source[1].to(device)
                ans = ans.to(device)
                src_data_wrap = (scr_data, scr_lengths, ans)
                passage = src_data_wrap[0].permute(1, 0)

                tgt_data = tgt_data.to(device)
Esempio n. 28
0
    def __init__(self, training_scene, training_objects, config, arguments):

        self.config = config
        self.arguments = arguments

        self.training_scene = training_scene
        self.training_objects = training_objects

        self.use_gae = arguments.get('use_gae')
        self.num_epochs = arguments.get('num_epochs')
        self.num_episodes = arguments.get('num_episodes')
        self.num_iters = arguments.get('num_iters')
        self.gamma = arguments.get('gamma')
        self.lamb = arguments.get('lamb')
        self.lr = arguments.get('lr')
        self.joint_loss = arguments.get('joint_loss')
        self.ec = arguments.get('ec')
        self.vc = arguments.get('vc')
        self.max_grad_norm = arguments.get('max_gradient_norm')
        self.dropout = arguments.get('dropout')
        self.decay = arguments.get('decay')
        self.reuse = arguments.get('share_latent')
        self.gpu_fraction = arguments.get('gpu_fraction')

        self.rollouts = []
        if arguments['embed']:
            self.embeddings = pickle.load(
                open(config['embeddings_fasttext'], 'rb'))
            for obj in training_objects:
                self.rollouts.append(
                    Rollout(training_scene, obj, config, arguments,
                            self.embeddings[obj].tolist()))
        else:
            self.embeddings = np.identity(len(self.training_objects))
            for i, obj in enumerate(self.training_objects):
                self.rollouts.append(
                    Rollout(training_scene, obj, config, arguments,
                            self.embeddings[i].tolist()))

        self.env = AI2ThorDumpEnv(training_scene, training_objects[0], config,
                                  arguments)

        tf.reset_default_graph()

        self.PGNetwork = A2C(name='A2C',
                             state_size=self.env.features.shape[1],
                             action_size=self.env.action_space,
                             history_size=arguments['history_size'],
                             embedding_size=300 if arguments['embed'] else len(
                                 self.training_objects),
                             entropy_coeff=self.ec,
                             value_function_coeff=self.vc,
                             max_gradient_norm=self.max_grad_norm,
                             dropout=self.dropout,
                             joint_loss=self.joint_loss,
                             learning_rate=self.lr,
                             decay=self.decay,
                             reuse=bool(self.reuse))

        if self.decay:
            self.PGNetwork.set_lr_decay(
                self.lr, self.num_epochs * self.num_episodes * self.num_iters)

        print("\nInitialized network with {} trainable weights.".format(
            len(self.PGNetwork.find_trainable_variables('A2C', True))))

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_fraction)

        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver()

        timer = "{}_{}_{}".format(str(datetime.now()).replace(" ", "-").replace(".", "-").replace(":", "-"), \
               training_scene, "_".join(training_objects))
        self.log_folder = os.path.join(arguments.get('logging'), timer)
        self.writer = tf.summary.FileWriter(self.log_folder)

        self.timer = timer

        self.reward_logs = []
        self.success_logs = []
        self.redundant_logs = []

        test_name = training_scene
        for training_object in training_objects:
            self.reward_logs.append(
                tf.placeholder(tf.float32,
                               name="rewards_{}".format(training_object)))
            self.success_logs.append(
                tf.placeholder(tf.float32,
                               name="success_{}".format(training_object)))
            self.redundant_logs.append(
                tf.placeholder(tf.float32,
                               name="redundant_{}".format(training_object)))

            tf.summary.scalar(test_name + "/" + training_object + "/rewards",
                              self.reward_logs[-1])
            tf.summary.scalar(
                test_name + "/" + training_object + "/success_rate",
                self.success_logs[-1])
            tf.summary.scalar(
                test_name + "/" + training_object + "/redundants",
                self.redundant_logs[-1])

        self.write_op = tf.summary.merge_all()
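
A sketch of how the merged write_op above is typically evaluated and logged once per epoch, assuming the TF1-style summary API used in this example; the placeholder, tag, log directory, and feed value below are hypothetical:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

reward_ph = tf.placeholder(tf.float32, name="rewards_example")
tf.summary.scalar("scene/object/rewards", reward_ph)
write_op = tf.summary.merge_all()

with tf.Session() as sess:
    writer = tf.summary.FileWriter("/tmp/tf_logs")
    summary = sess.run(write_op, feed_dict={reward_ph: 1.25})
    writer.add_summary(summary, global_step=0)
    writer.flush()
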
Esempio n. 29
0
    def __init__(self,
                 model,
                 num_actions,
                 nenvs,
                 lr,
                 epsilon,
                 gamma=0.99,
                 lam=0.95,
                 lstm_unit=256,
                 value_factor=0.5,
                 entropy_factor=0.01,
                 time_horizon=128,
                 batch_size=32,
                 epoch=3,
                 grad_clip=40.0,
                 state_shape=[84, 84, 1],
                 phi=lambda s: s,
                 use_lstm=False,
                 continuous=False,
                 upper_bound=1.0,
                 name='ppo',
                 training=True):
        self.num_actions = num_actions
        self.gamma = gamma
        self.lam = lam
        self.lstm_unit = lstm_unit
        self.name = name
        self.state_shape = state_shape
        self.nenvs = nenvs
        self.lr = lr
        self.epsilon = epsilon
        self.time_horizon = time_horizon
        self.batch_size = batch_size
        self.epoch = epoch
        self.phi = phi 
        self.use_lstm = use_lstm
        self.continuous = continuous
        self.upper_bound = upper_bound
        self.episode_experience = []
        self.all_experience = []
        self.ep_count = 0

        self._act, self._train = build_train(
            model=model,
            num_actions=num_actions,
            lr=lr.get_variable(),
            epsilon=epsilon.get_variable(),
            nenvs=nenvs,
            step_size=batch_size,
            lstm_unit=lstm_unit,
            state_shape=state_shape,
            grad_clip=grad_clip,
            value_factor=value_factor,
            entropy_factor=entropy_factor,
            continuous=continuous,
            scope=name
        )

        self.initial_state = np.zeros((nenvs, lstm_unit*2), np.float32)
        self.rnn_state = self.initial_state

        self.state_tm1 = dict(obs=None, action=None, value=None,
                              log_probs=None, done=None, rnn_state=None)
        self.rollouts = [Rollout() for _ in range(nenvs)]
        self.t = 0
        self.training = training
Esempio n. 30
0
class MultitaskPolicy(object):
    def __init__(self, map_index, policies, writer, write_op, num_task,
                 num_iters, num_episode, num_epochs, gamma, lamb, plot_model,
                 save_model, save_name, share_exp, use_laser, use_gae,
                 noise_argmax):

        self.map_index = map_index
        self.PGNetwork = policies

        self.writer = writer
        self.write_op = write_op
        self.use_gae = use_gae

        self.num_task = num_task
        self.num_iters = num_iters
        self.num_epochs = num_epochs
        self.num_episode = num_episode

        self.gamma = gamma
        self.lamb = lamb
        self.save_name = save_name
        self.plot_model = plot_model
        self.save_model = save_model

        self.share_exp = share_exp
        self.noise_argmax = noise_argmax

        self.env = Terrain(self.map_index, use_laser)

        assert self.num_task <= self.env.num_task

        self.plot_figure = PlotFigure(self.save_name, self.env, self.num_task)

        self.rollout = Rollout(
            num_task=self.num_task,
            num_episode=self.num_episode,
            num_iters=self.num_iters,
            map_index=self.map_index,
            use_laser=use_laser,
            noise_argmax=self.noise_argmax,
        )

    def _discount_rewards(self, episode_rewards, episode_nexts, task,
                          current_value):
        discounted_episode_rewards = np.zeros_like(episode_rewards)
        next_value = 0.0
        if episode_rewards[-1] == 1:
            next_value = 0.0
        else:
            next_value = current_value[episode_nexts[-1][0],
                                       episode_nexts[-1][1], task]

        for i in reversed(range(len(episode_rewards))):
            next_value = episode_rewards[i] + self.gamma * next_value
            discounted_episode_rewards[i] = next_value

        return discounted_episode_rewards.tolist()

    def _GAE(self, episode_rewards, episode_states, episode_nexts, task,
             current_value):
        ep_GAE = np.zeros_like(episode_rewards)
        TD_error = np.zeros_like(episode_rewards)
        lamda = 0.96

        next_value = 0.0
        if episode_rewards[-1] == 1:
            next_value = 0.0
        else:
            next_value = current_value[episode_nexts[-1][0],
                                       episode_nexts[-1][1], task]

        for i in reversed(range(len(episode_rewards))):
            TD_error[i] = episode_rewards[
                i] + self.gamma * next_value - current_value[
                    episode_states[i][0], episode_states[i][1], task]
            next_value = current_value[episode_states[i][0],
                                       episode_states[i][1], task]

        ep_GAE[len(episode_rewards) - 1] = TD_error[len(episode_rewards) - 1]
        weight = self.gamma * lamda
        for i in reversed(range(len(episode_rewards) - 1)):
            ep_GAE[i] += TD_error[i] + weight * ep_GAE[i + 1]

        return ep_GAE.tolist()

    def _prepare_current_policy(self, sess, epoch):
        current_policy = {}

        for task in range(self.num_task):
            for (x, y) in self.env.state_space:
                state_index = self.env.state_to_index[y][x]
                logit, pi = sess.run(
                    [
                        self.PGNetwork[task].actor.logits,
                        self.PGNetwork[task].actor.pi
                    ],
                    feed_dict={
                        self.PGNetwork[task].actor.inputs:
                        [self.env.cv_state_onehot[state_index]],
                    })

                current_policy[x, y, task, 0] = logit.ravel().tolist()
                current_policy[x, y, task, 1] = pi.ravel().tolist()

        if (epoch + 1) % self.plot_model == 0:
            self.plot_figure.plot(current_policy, epoch + 1)

        return current_policy

    def _prepare_current_values(self, sess, epoch):
        current_values = {}

        for task in range(self.num_task):
            for (x, y) in self.env.state_space:
                state_index = self.env.state_to_index[y][x]
                v = sess.run(self.PGNetwork[task].critic.value,
                             feed_dict={
                                 self.PGNetwork[task].critic.inputs:
                                 [self.env.cv_state_onehot[state_index]],
                             })

                current_values[x, y, task] = v.ravel().tolist()[0]

        # if (epoch+1) % self.plot_model == 0 or epoch == 0:
        # 	self.plot_figure.plot(current_values, epoch + 1)

        return current_values

    def _prepare_current_neglog(self, sess, epoch):
        current_neglog = {}

        for task in range(self.num_task):
            for (x, y) in self.env.state_space:
                for action in range(self.env.action_size):
                    state_index = self.env.state_to_index[y][x]
                    neglog = sess.run(
                        self.PGNetwork[task].actor.neg_log_prob,
                        feed_dict={
                            self.PGNetwork[task].actor.inputs:
                            [self.env.cv_state_onehot[state_index]],
                            self.PGNetwork[task].actor.actions:
                            [self.env.cv_action_onehot[action]]
                        })
                    current_neglog[x, y, action,
                                   task] = neglog.ravel().tolist()[0]

        return current_neglog

    def _make_batch(self, sess, epoch):

        current_policy = self._prepare_current_policy(sess, epoch)
        current_values = self._prepare_current_values(sess, epoch)
        '''
        states = [
            task_1      [[---episode_1---], ..., [---episode_n---]],
            task_2      [[---episode_1---], ..., [---episode_n---]],
            ...
            task_k      [[---episode_1---], ..., [---episode_n---]],
        ]
        same layout for actions, tasks, rewards, values, dones

        last_values = [
            task_1      [---episode_1---, ..., ---episode_n---],
            task_2      [---episode_1---, ..., ---episode_n---],
            ...
            task_k      [---episode_1---, ..., ---episode_n---],
        ]
        '''
        states, tasks, actions, rewards, next_states, redundant_steps = self.rollout.rollout_batch(
            current_policy, epoch)

        observations = [[] for i in range(self.num_task)]
        converted_actions = [[] for i in range(self.num_task)]
        logits = [[] for i in range(self.num_task)]

        for task_idx, task_states in enumerate(states):
            for ep_idx, ep_states in enumerate(task_states):
                observations[task_idx] += [
                    self.env.cv_state_onehot[self.env.state_to_index[s[1]][
                        s[0]]] for s in ep_states
                ]
                converted_actions[task_idx] += [
                    self.env.cv_action_onehot[a]
                    for a in actions[task_idx][ep_idx]
                ]
                logits[task_idx] += [
                    current_policy[s[0], s[1], task_idx, 0] for s in ep_states
                ]

        returns = [[] for i in range(self.num_task)]
        advantages = [[] for i in range(self.num_task)]

        if not self.use_gae:
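            # Without GAE: bootstrapped discounted returns are the critic targets and
            # the advantage is simply A(s, a) = G_t - V(s).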

            for task_idx in range(self.num_task):
                for ep_idx, (ep_rewards, ep_states,
                             ep_next_states) in enumerate(
                                 zip(rewards[task_idx], states[task_idx],
                                     next_states[task_idx])):
                    ep_dones = list(np.zeros_like(ep_rewards))

                    ep_returns = self._discount_rewards(
                        ep_rewards, ep_next_states, task_idx, current_values)
                    returns[task_idx] += ep_returns

                    ep_values = [
                        current_values[s[0], s[1], task_idx] for s in ep_states
                    ]

                    # Advantage A(s, a) = G_t - V(s), where G_t is the bootstrapped
                    # return r + gamma * V(s') produced by _discount_rewards above.
                    advantages[task_idx] += list(
                        (np.array(ep_returns) - np.array(ep_values)).astype(
                            np.float32))

                assert len(returns[task_idx]) == len(advantages[task_idx])
        else:
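            # With GAE: returns remain the critic targets, while advantages are the
            # lambda-weighted sums of TD errors computed in _GAE.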

            for task_idx in range(self.num_task):
                for ep_idx, (ep_rewards, ep_states,
                             ep_next_states) in enumerate(
                                 zip(rewards[task_idx], states[task_idx],
                                     next_states[task_idx])):
                    ep_dones = list(np.zeros_like(ep_rewards))

                    returns[task_idx] += self._discount_rewards(
                        ep_rewards, ep_next_states, task_idx, current_values)
                    advantages[task_idx] += self._GAE(ep_rewards, ep_states,
                                                      ep_next_states, task_idx,
                                                      current_values)

                assert len(returns[task_idx]) == len(advantages[task_idx])

        for task_idx in range(self.num_task):
            states[task_idx] = np.concatenate(states[task_idx])
            actions[task_idx] = np.concatenate(actions[task_idx])
            redundant_steps[task_idx] = np.mean(redundant_steps[task_idx])

        share_observations = [[] for _ in range(self.num_task)]
        share_actions = [[] for _ in range(self.num_task)]
        share_advantages = [[] for _ in range(self.num_task)]
        share_logits = [[] for _ in range(self.num_task)]

        if self.share_exp:
            assert self.num_task > 1

            for task_idx in range(self.num_task):
                for idx, s in enumerate(states[task_idx]):
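                    # Cells whose MAP value is 2 presumably mark the region shared
                    # between tasks; only transitions taken there are copied over.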

                    if self.env.MAP[s[1]][s[0]] == 2:

                        act = actions[task_idx][idx]

                        # and share with other tasks
                        for other_task in range(self.num_task):
                            if other_task == task_idx:
                                continue

                            share_observations[other_task].append(
                                self.env.cv_state_onehot[
                                    self.env.state_to_index[s[1]][s[0]]])
                            share_actions[other_task].append(
                                self.env.cv_action_onehot[act])
                            share_advantages[other_task].append(
                                advantages[task_idx][idx])
                            share_logits[other_task].append(
                                current_policy[s[0], s[1], task_idx, 0])

        return observations, converted_actions, returns, advantages, logits, rewards, share_observations, share_actions, share_advantages, share_logits, redundant_steps

    def train(self, sess, saver):
        total_samples = {}

        for epoch in range(self.num_epochs):
            # sys.stdout.flush()

            # ROLLOUT SAMPLE
            #---------------------------------------------------------------------------------------------------------------------#
            mb_states, mb_actions, mb_returns, mb_advantages, mb_logits, rewards, mbshare_states, mbshare_actions, mbshare_advantages, mbshare_logits, mb_redundant_steps = self._make_batch(
                sess, epoch)
            #---------------------------------------------------------------------------------------------------------------------#

            print('epoch {}/{}'.format(epoch, self.num_epochs),
                  end='\r',
                  flush=True)
            # UPDATE NETWORK
            #---------------------------------------------------------------------------------------------------------------------#
            sum_dict = {}
            for task_idx in range(self.num_task):
                assert len(mb_states[task_idx]) == len(
                    mb_actions[task_idx]) == len(mb_returns[task_idx]) == len(
                        mb_advantages[task_idx]) == len(mb_logits[task_idx])
                assert len(mbshare_states[task_idx]) == len(
                    mbshare_actions[task_idx]) == len(
                        mbshare_advantages[task_idx]) == len(
                            mbshare_logits[task_idx])

                if not self.share_exp:
                    policy_loss, value_loss, ratio = self.PGNetwork[
                        task_idx].learn(sess, mb_states[task_idx],
                                        mb_actions[task_idx],
                                        mb_returns[task_idx],
                                        mb_advantages[task_idx],
                                        mb_logits[task_idx])
                else:
                    value_loss = self.PGNetwork[task_idx].learn_critic(
                        sess, mb_states[task_idx], mb_returns[task_idx])

                    policy_loss, ratio = self.PGNetwork[task_idx].learn_actor(
                        sess, mb_states[task_idx] + mbshare_states[task_idx],
                        mb_actions[task_idx] + mbshare_actions[task_idx],
                        mb_advantages[task_idx] + mbshare_advantages[task_idx],
                        mb_logits[task_idx] + mbshare_logits[task_idx])

                sum_dict[self.PGNetwork[task_idx].mean_reward] = np.sum(
                    np.concatenate(rewards[task_idx])) / len(rewards[task_idx])
                sum_dict[self.PGNetwork[task_idx].
                         mean_redundant] = mb_redundant_steps[task_idx]
                sum_dict[self.PGNetwork[task_idx].ratio_ph] = np.mean(ratio)

                if task_idx not in total_samples:
                    total_samples[task_idx] = 0

                total_samples[task_idx] += len(
                    list(np.concatenate(rewards[task_idx])))

            #---------------------------------------------------------------------------------------------------------------------#

            # WRITE TF SUMMARIES
            #---------------------------------------------------------------------------------------------------------------------#
            summary = sess.run(self.write_op, feed_dict=sum_dict)

            self.writer.add_summary(summary, total_samples[0])
            self.writer.flush()
            #---------------------------------------------------------------------------------------------------------------------#

            # SAVE MODEL
            #---------------------------------------------------------------------------------------------------------------------#
            # if epoch % self.save_model == 0:
            # 	saver.save(sess, 'checkpoints/' + self.save_name + '.ckpt')
            #---------------------------------------------------------------------------------------------------------------------#
Esempio n. 31
0
def main(pretrain_checkpoint_dir,
         train_summary_writer,
         vocab: Vocab,
         dataloader: DataLoader,
         batch_size: int = 64,
         embedding_dim: int = 256,
         seq_length: int = 3000,
         gen_seq_len: int = 3000,
         gen_rnn_units: int = 1024,
         disc_rnn_units: int = 1024,
         epochs: int = 40000,
         pretrain_epochs: int = 4000,
         learning_rate: float = 1e-4,
         rollout_num: int = 2,
         gen_pretrain: bool = False,
         disc_pretrain: bool = False,
         load_gen_weights: bool = False,
         load_disc_weights: bool = False,
         save_gen_weights: bool = True,
         save_disc_weights: bool = True,
         disc_steps: int = 3,
         shuffle: bool = True):
    gen = Generator(dataloader=dataloader,
                    vocab=vocab,
                    batch_size=batch_size,
                    embedding_dim=embedding_dim,
                    seq_length=seq_length,
                    checkpoint_dir=pretrain_checkpoint_dir,
                    rnn_units=gen_rnn_units,
                    start_token=0,
                    learning_rate=learning_rate)
    if load_gen_weights:
        gen.load_weights()
    if gen_pretrain:
        gen_pre_trainer = GenPretrainer(gen,
                                        dataloader=dataloader,
                                        vocab=vocab,
                                        pretrain_epochs=pretrain_epochs,
                                        tb_writer=train_summary_writer,
                                        learning_rate=learning_rate)
        print('Start pre-training generator...')
        gen_pre_trainer.pretrain(gen_seq_len=gen_seq_len,
                                 save_weights=save_gen_weights)

    disc = Discriminator(vocab_size=vocab.vocab_size,
                         embedding_dim=embedding_dim,
                         rnn_units=disc_rnn_units,
                         batch_size=batch_size,
                         checkpoint_dir=pretrain_checkpoint_dir,
                         learning_rate=learning_rate)
    if load_disc_weights:
        disc.load_weights()
    if disc_pretrain:
        disc_pre_trainer = DiscPretrainer(disc,
                                          gen,
                                          dataloader=dataloader,
                                          vocab=vocab,
                                          pretrain_epochs=pretrain_epochs,
                                          tb_writer=train_summary_writer,
                                          learning_rate=learning_rate)
        print('Start pre-training discriminator...')
        disc_pre_trainer.pretrain(save_disc_weights)
    rollout = Rollout(generator=gen,
                      discriminator=disc,
                      vocab=vocab,
                      batch_size=batch_size,
                      seq_length=seq_length,
                      rollout_num=rollout_num)

    with tqdm(desc='Epoch: ', total=epochs, dynamic_ncols=True) as pbar:
        for epoch in range(epochs):
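            # Adversarial step: sample sequences from the generator, estimate rewards
            # with discriminator-guided rollouts, update the generator on those rewards,
            # then train the discriminator on fake vs. real batches for disc_steps steps.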
            fake_samples = gen.generate()
            rewards = rollout.get_reward(samples=fake_samples)
            gen_loss = gen.train_step(fake_samples, rewards)
            real_samples, _ = dataloader.get_batch(shuffle=shuffle,
                                                   seq_length=seq_length,
                                                   batch_size=batch_size,
                                                   training=True)
            disc_loss = 0
            for i in range(disc_steps):
                disc_loss += disc.train_step(fake_samples,
                                             real_samples) / disc_steps

            with train_summary_writer.as_default():
                tf.summary.scalar('gen_train_loss', gen_loss, step=epoch)
                tf.summary.scalar('disc_train_loss', disc_loss, step=epoch)
                tf.summary.scalar('total_train_loss',
                                  disc_loss + gen_loss,
                                  step=epoch)

            pbar.set_postfix(gen_train_loss=tf.reduce_mean(gen_loss),
                             disc_train_loss=tf.reduce_mean(disc_loss),
                             total_train_loss=tf.reduce_mean(gen_loss +
                                                             disc_loss))

            if (epoch + 1) % 5 == 0 or (epoch + 1) == 1:
                print('Saving weights...')
                # save weights
                gen.model.save_weights(gen.checkpoint_prefix)
                disc.model.save_weights(disc.checkpoint_prefix)
                # gen.model.save('gen.h5')
                # disc.model.save('disc.h5')

                # test the discriminator
                fake_samples = gen.generate(gen_seq_len)
                real_samples, _ = dataloader.get_batch(shuffle=shuffle,
                                                       seq_length=gen_seq_len,
                                                       batch_size=batch_size,
                                                       training=False)
                disc_loss = disc.test_step(fake_samples, real_samples)

                # test the generator
                gen_loss = gen.test_step()

                # compute the BLEU score
                # bleu_score = get_bleu_score(true_seqs=real_samples, genned_seqs=fake_samples)
                genned_sentences = vocab.extract_seqs(fake_samples)
                # print(genned_sentences)
                # print(vocab.idx2char[fake_samples[0]])

                # record test losses
                with train_summary_writer.as_default():
                    tf.summary.scalar('disc_test_loss',
                                      tf.reduce_mean(disc_loss),
                                      step=epoch)
                    tf.summary.scalar('gen_test_loss',
                                      tf.reduce_mean(gen_loss),
                                      step=epoch)
                    # tf.summary.scalar('bleu_score', tf.reduce_mean(bleu_score), step=epoch + gen_pretrain * pretrain_epochs)

            pbar.update()
Esempio n. 32
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    vocab_size = 2000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, EMB_DIM, HIDDEN_DIM, 1, START_TOKEN,
                          SEQ_LENGTH).to(device)
    target_lstm = Generator(vocab_size,
                            EMB_DIM,
                            HIDDEN_DIM,
                            1,
                            START_TOKEN,
                            SEQ_LENGTH,
                            oracle=True).to(device)
    discriminator = Discriminator(vocab_size, dis_embedding_dim,
                                  dis_filter_sizes, dis_num_filters,
                                  dis_dropout).to(device)

    generate_samples(target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    pre_gen_opt = torch.optim.Adam(generator.parameters(), 1e-2)
    adv_gen_opt = torch.optim.Adam(generator.parameters(), 1e-2)
    dis_opt = torch.optim.Adam(discriminator.parameters(), 1e-4)
    dis_criterion = nn.NLLLoss()

    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(generator, pre_gen_opt, gen_data_loader)
        if (epoch + 1) % 5 == 0:
            generate_samples(generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch + 1, '\tnll:\t', test_loss)
            buffer = 'epoch:\t' + str(epoch +
                                      1) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)
    print('Start pre-training discriminator...')
    # Train for 3 epochs on the generated data, and repeat this 50 times
    for e in range(50):
        generate_samples(generator, BATCH_SIZE, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        d_total_loss = []
        for _ in range(3):
            dis_data_loader.reset_pointer()
            total_loss = []
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                dis_output = discriminator(x_batch.detach())
                d_loss = dis_criterion(dis_output, y_batch.detach())
                dis_opt.zero_grad()
                d_loss.backward()
                dis_opt.step()
                total_loss.append(d_loss.data.cpu().numpy())
            d_total_loss.append(np.mean(total_loss))
        if (e + 1) % 5 == 0:
            buffer = 'Epoch [{}], discriminator loss [{:.4f}]\n'.format(
                e + 1, np.mean(d_total_loss))
            print(buffer)
            log.write(buffer)

    rollout = Rollout(generator, 0.8)
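    # (the 0.8 above is presumably the rate rollout.update_params() uses to move the
    # rollout policy toward the current generator weights)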
    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    gan_loss = GANLoss()
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        discriminator.eval()
        for it in range(1):
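            # One policy-gradient step: per-token rewards come from Monte-Carlo rollouts
            # (16 per sample here, judging by the second argument) scored by the discriminator.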
            samples, _ = generator.sample(num_samples=BATCH_SIZE)
            rewards = rollout.get_reward(samples, 16, discriminator)
            prob = generator(samples.detach())
            adv_loss = gan_loss(prob, samples.detach(), rewards.detach())
            adv_gen_opt.zero_grad()
            adv_loss.backward()
            nn.utils.clip_grad_norm_(generator.parameters(), 5.0)
            adv_gen_opt.step()

        # Test
        if (total_batch + 1) % 5 == 0:
            generate_samples(generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(target_lstm, likelihood_data_loader)
            self_bleu_score = self_bleu(generator)
            buffer = 'epoch:\t' + str(total_batch + 1) + '\tnll:\t' + str(
                test_loss) + '\tSelf Bleu:\t' + str(self_bleu_score) + '\n'
            print(buffer)
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        discriminator.train()
        for _ in range(5):
            generate_samples(generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            d_total_loss = []
            for _ in range(3):
                dis_data_loader.reset_pointer()
                total_loss = []
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    x_batch = x_batch.to(device)
                    y_batch = y_batch.to(device)
                    dis_output = discriminator(x_batch.detach())
                    d_loss = dis_criterion(dis_output, y_batch.detach())
                    dis_opt.zero_grad()
                    d_loss.backward()
                    dis_opt.step()
                    total_loss.append(d_loss.data.cpu().numpy())
                d_total_loss.append(np.mean(total_loss))
            if (total_batch + 1) % 5 == 0:
                buffer = 'Epoch [{}], discriminator loss [{:.4f}]\n'.format(
                    total_batch + 1, np.mean(d_total_loss))
                print(buffer)
                log.write(buffer)
    log.close()
Esempio n. 33
0
class Agent:
    def __init__(self,
                 actions,
                 optimizer,
                 convs,
                 fcs,
                 padding,
                 lstm,
                 gamma=0.99,
                 lstm_unit=256,
                 time_horizon=5,
                 policy_factor=1.0,
                 value_factor=0.5,
                 entropy_factor=0.01,
                 grad_clip=40.0,
                 state_shape=[84, 84, 1],
                 buffer_size=2e3,
                 rp_frame=3,
                 phi=lambda s: s,
                 name='global'):
        self.actions = actions
        self.gamma = gamma
        self.name = name
        self.time_horizon = time_horizon
        self.state_shape = state_shape
        self.rp_frame = rp_frame
        self.phi = phi

        self._act,\
        self._train,\
        self._update_local = build_graph.build_train(
            convs=convs,
            fcs=fcs,
            padding=padding,
            lstm=lstm,
            num_actions=len(actions),
            optimizer=optimizer,
            lstm_unit=lstm_unit,
            state_shape=state_shape,
            grad_clip=grad_clip,
            policy_factor=policy_factor,
            value_factor=value_factor,
            entropy_factor=entropy_factor,
            rp_frame=rp_frame,
            scope=name
        )

        # rnn state variables
        self.initial_state = np.zeros((1, lstm_unit), np.float32)
        self.rnn_state0 = self.initial_state
        self.rnn_state1 = self.initial_state

        # last state variables
        self.zero_state = np.zeros(state_shape, dtype=np.float32)
        self.initial_last_obs = [self.zero_state for _ in range(rp_frame)]
        self.last_obs = deque(self.initial_last_obs, maxlen=rp_frame)
        self.last_action = deque([0, 0], maxlen=2)
        self.value_tm1 = None
        self.reward_tm1 = 0.0

        # buffers
        self.rollout = Rollout()
        self.buffer = ReplayBuffer(capacity=buffer_size)

        self.t = 0
        self.t_in_episode = 0

    def train(self, bootstrap_value):
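        # One A3C update over the collected rollout, plus UNREAL-style auxiliary
        # updates (reward prediction and value-function replay) sampled from the buffer.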
        # prepare A3C update
        obs_t = np.array(self.rollout.obs_t, dtype=np.float32)
        actions_t = np.array(self.rollout.actions_t, dtype=np.uint8)
        actions_tm1 = np.array(self.rollout.actions_tm1, dtype=np.uint8)
        rewards_tp1 = self.rollout.rewards_tp1
        rewards_t = self.rollout.rewards_t
        values_t = self.rollout.values_t
        state_t0 = self.rollout.states_t[0][0]
        state_t1 = self.rollout.states_t[0][1]

        # compute returns
        R = bootstrap_value
        returns_t = []
        for reward in reversed(rewards_tp1):
            R = reward + self.gamma * R
            returns_t.append(R)
        returns_t = np.array(list(reversed(returns_t)))
        adv_t = returns_t - values_t

        # prepare reward prediction update
        rp_obs, rp_reward_tp1 = self.buffer.sample_rp()

        # prepare value function replay update
        vr_obs_t,\
        vr_actions_tm1,\
        vr_rewards_t,\
        is_terminal = self.buffer.sample_vr(self.time_horizon)
        _, vr_values_t, _ = self._act(vr_obs_t, vr_actions_tm1, vr_rewards_t,
                                      self.initial_state, self.initial_state)
        vr_values_t = np.reshape(vr_values_t, [-1])
        if is_terminal:
            vr_bootstrap_value = 0.0
        else:
            vr_bootstrap_value = vr_values_t[-1]

        # compute returns for value prediction
        R = vr_bootstrap_value
        vr_returns_t = []
        for reward in reversed(vr_rewards_t[:-1]):
            R = reward + self.gamma * R
            vr_returns_t.append(R)
        vr_returns_t = np.array(list(reversed(vr_returns_t)))

        # update
        loss = self._train(
            obs_t=obs_t,
            rnn_state0=state_t0,
            rnn_state1=state_t1,
            actions_t=actions_t,
            rewards_t=rewards_t,
            actions_tm1=actions_tm1,
            returns_t=returns_t,
            advantages_t=adv_t,
            rp_obs=rp_obs,
            rp_reward_tp1=rp_reward_tp1,
            vr_obs_t=vr_obs_t[:-1],
            vr_actions_tm1=vr_actions_tm1[:-1],
            vr_rewards_t=vr_rewards_t[:-1],
            vr_returns_t=vr_returns_t
        )
        self._update_local()
        return loss

    def act(self, obs_t, reward_t, training=True):
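        # Sample an action from the policy and, during training, store the previous
        # transition in the on-policy rollout and the auxiliary replay buffer; an
        # update is triggered every time_horizon steps.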
        # change state shape to WHC
        obs_t = self.phi(obs_t)
        # last transitions
        action_tm2, action_tm1 = self.last_action
        obs_tm1 = self.last_obs[-1]
        # take next action
        prob, value, rnn_state = self._act(
            obs_t=[obs_t],
            actions_tm1=[action_tm1],
            rewards_t=[reward_t],
            rnn_state0=self.rnn_state0,
            rnn_state1=self.rnn_state1
        )
        action_t = np.random.choice(range(len(self.actions)), p=prob[0])

        if training:
            if len(self.rollout.obs_t) == self.time_horizon:
                self.train(self.value_tm1)
                self.rollout.flush()

            if self.t_in_episode > 0:
                # add transition to buffer for A3C update
                self.rollout.add(
                    obs_t=obs_tm1,
                    reward_tp1=reward_t,
                    reward_t=self.reward_tm1,
                    action_t=action_tm1,
                    action_tm1=action_tm2,
                    value_t=self.value_tm1,
                    terminal_tp1=False,
                    state_t=[self.rnn_state0, self.rnn_state1]
                )
                # add transition to buffer for auxiliary update
                self.buffer.add(
                    obs_t=list(self.last_obs),
                    action_tm1=action_tm2,
                    reward_t=self.reward_tm1,
                    action_t=action_tm1,
                    reward_tp1=reward_t,
                    obs_tp1=obs_t,
                    terminal=False
                )

        self.t += 1
        self.t_in_episode += 1
        self.rnn_state0, self.rnn_state1 = rnn_state
        self.last_obs.append(obs_t)
        self.last_action.append(action_t)
        self.value_tm1 = value[0][0]
        self.reward_tm1 = reward_t
        return self.actions[action_t]

    def stop_episode(self, obs_t, reward_t, training=True):
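        # Episode end: record the terminal transition, run a final update with a zero
        # bootstrap value, and reset the recurrent and bookkeeping state.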
        # change state shape to WHC
        obs_t = self.phi(obs_t)
        # last transitions
        action_tm2, action_tm1 = self.last_action
        obs_tm1 = self.last_obs[-1]
        if training:
            # add transition for A3C update
            self.rollout.add(
                obs_t=obs_tm1,
                action_t=action_tm1,
                reward_t=self.reward_tm1,
                reward_tp1=reward_t,
                action_tm1=action_tm2,
                value_t=self.value_tm1,
                state_t=[self.rnn_state0, self.rnn_state1],
                terminal_tp1=True
            )
            # add transition for auxiliary update
            self.buffer.add(
                obs_t=list(self.last_obs),
                action_tm1=action_tm2,
                reward_t=self.reward_tm1,
                action_t=action_tm1,
                reward_tp1=reward_t,
                obs_tp1=obs_t,
                terminal=True
            )
            self.train(0.0)
            self.rollout.flush()
        self.rnn_state0 = self.initial_state
        self.rnn_state1 = self.initial_state
        self.last_obs = deque(self.initial_last_obs, maxlen=self.rp_frame)
        self.last_action = deque([0, 0], maxlen=2)
        self.value_tm1 = None
        self.reward_tm1 = 0.0
        self.t_in_episode = 0