Exemplo n.º 1
0
 def setUp(self) -> None:
     """Create a new ``queue()`` and store it on ``self.queue``."""
     self.queue = queue()
Exemplo n.º 2
0
    def train_controller(self, model, optimizer, device, train_loader,
                         valid_loader, epoch, momentum, entropy_weight,
                         child_retrain_epoch, child_retrain_interval):
        """Train the controller ``model`` from the validation accuracy of sampled children.

        For each of ``epoch`` controller epochs, ``self.num_of_children``
        architectures are sampled by calling ``model()`` without input. Each
        sample is turned into a child config with ``self.make_enas_config``,
        trained with ``self.train_child`` and evaluated with
        ``self.test_child``. The reward is the validation accuracy plus
        ``entropy_weight * sampled_entropies`` minus the mean of ``prev_runs``,
        and ``-sampled_logprobs * reward`` is accumulated into the epoch loss.
        The optimizer step is skipped for ``epoch_idx == 0``.

        Args:
            model: Controller; after ``model()`` it must expose
                ``sampled_architecture``, ``sampled_entropies`` and
                ``sampled_logprobs``.
            optimizer: Optimizer stepped once per epoch (except epoch 0).
            device: Device the child network is moved to.
            train_loader: Passed to ``self.train_child`` and ``self.retrain``.
            valid_loader: Passed to ``self.test_child`` and ``self.retrain``.
            epoch: Number of controller epochs to run.
            momentum: Passed to ``SharedEnasChild`` when ``self.isShared`` is
                false.
            entropy_weight: Multiplier for the sampled entropies added to the
                reward.
            child_retrain_epoch: Passed to ``self.retrain``.
            child_retrain_interval: The best child of an epoch is retrained
                and ``self.save`` is called when
                ``epoch_idx % child_retrain_interval == 0``.

        Returns:
            ``prev_runs`` as produced by the last ``queue(...)`` call, or the
            initial zero tensor of length 5 when ``epoch`` is 0.
        """

        #TODO: entropy_weight
        model.train()

        # Global counter over all sampled children; used as tensorboard step.
        step = 0

        prev_runs = torch.zeros([5])  # to store the val_acc of prev epochs

        for epoch_idx in range(epoch):
            # Fresh loss accumulator for this epoch.
            loss = torch.FloatTensor([0])

            # Per-child validation accuracy and config, to pick the best child.
            epoch_valacc = torch.zeros(self.num_of_children)
            epoch_childs = []

            for child_idx in range(self.num_of_children):
                # images, labels.cuda()

                step += 1

                model()  # forward pass without input
                sampled_architecture = model.sampled_architecture
                # Entropies are detached, so they feed the reward only as a value.
                sampled_entropies = model.sampled_entropies.detach()
                sampled_logprobs = model.sampled_logprobs

                # get the acc of a single child

                # make child
                conf = self.make_enas_config(sampled_architecture)
                epoch_childs.append(conf)

                print("CONF:", conf)
                if self.isShared:
                    # Reuse the single shared child network.
                    child = self.child.to(device)
                else:
                    # Build a new child network for this config.
                    child = SharedEnasChild(
                        conf,
                        self.num_layers,
                        self.learning_rate_child,
                        momentum,
                        num_classes=self.num_classes,
                        out_filters=self.out_filters,
                        input_shape=self.input_shape,
                        input_channels=self.input_channels).to(device)

                # self.logger.info("train_controller, epoch/child : ", epoch_idx, child_idx, " child : ", conf) # logging error

                # Train child
                self.train_child(child, conf, device, train_loader,
                                 self.epoch_child, epoch_idx, child_idx)

                # Test child
                validation_accuracy, validation_loss = self.test_child(
                    child, conf, device, valid_loader)

                reward = torch.tensor(validation_accuracy).detach()
                # In-place add of the entropy bonus to the reward.
                reward += sampled_entropies * entropy_weight

                # calculate advantage with baseline (moving avg)
                baseline = prev_runs.mean(
                )  # subtract baseline to reduce variance in rewards

                reward = reward - baseline

                # self.logger.info(prev_runs, baseline, reward) # logging error

                # In-place accumulation; keep the statement order as is.
                loss -= sampled_logprobs * reward
                epoch_valacc[child_idx] = validation_accuracy

                # logging to tensorboard
                self.writer.add_scalar("loss", loss.item(), global_step=step)
                self.writer.add_scalar("reward", reward, global_step=step)
                self.writer.add_scalar("valid_acc",
                                       validation_accuracy,
                                       global_step=step)
                self.writer.add_scalar("valid_loss",
                                       validation_loss,
                                       global_step=step)
                self.writer.add_scalar("sampled_entropies",
                                       sampled_entropies,
                                       global_step=step)
                self.writer.add_scalar("sampled_logprobs",
                                       sampled_logprobs,
                                       global_step=step)

            # Child with the highest validation accuracy in this epoch.
            best_child_idx = torch.argmax(epoch_valacc)
            best_child_conf = epoch_childs[best_child_idx]

            message = " best valacc" + str(epoch_valacc[best_child_idx].item()) \
                      + ' - config: ' + str(best_child_conf)

            self.writer.add_text("best child", message, global_step=epoch_idx)

            if epoch_idx % child_retrain_interval == 0:
                # NOTE(review): the two values returned by retrain are not used here.
                retrained_valacc, retrained_loss = self.retrain(
                    best_child_conf, device, train_loader, valid_loader,
                    child_retrain_epoch, epoch_idx)
                print("current best childs: ", self.bestchilds.bestchilds)
                self.save(epoch_idx)

            # No controller update (and no epoch-level logging) in the first epoch.
            if epoch_idx != 0:
                # training:
                loss.backward(
                    retain_graph=True
                )  # retain_graph=True keeps the graph after backward; NOTE(review): original author was unsure it is needed

                # to normalize gradients : grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradBound) #normalize gradient
                optimizer.step()
                model.zero_grad()

                # self.writer.add_histogram("sampled_branches", model.sampled_architecture, global_step=epoch_idx)
                # self.writer.add_histogram("sampled_connections", model.sampled_architecture[1], global_step=epoch_idx)
                self.writer.add_scalar("epoch_loss",
                                       loss.item(),
                                       global_step=epoch_idx)
                self.writer.add_scalar("epoch mean validation acc.",
                                       epoch_valacc.mean(),
                                       global_step=epoch_idx)

                #self.writer.add_graph(child) #ERROR:  TracedModules don't support parameter sharing between modules

            # presumably pushes this epoch's mean accuracy into the baseline
            # window; queue() is defined elsewhere, confirm its semantics
            prev_runs = queue(prev_runs, epoch_valacc.mean())

        return prev_runs
Exemplo n.º 3
0
 def test_append_overflow(self):
     """``append`` on a queue built with ``queue(0)`` must raise OverflowError."""
     subject = queue(0)
     with self.assertRaises(OverflowError):
         subject.append(1)
Exemplo n.º 4
0
def main(tick, config, q):
    """Run the per-tick faction maintenance: disband, abandon and spawn.

    Three steps are performed against the ``massive`` collection via ``cp``:

    1. When ``tick`` is a multiple of ``config['cleanFactions']``, queue a
       ``disband`` job for every faction; the job deletes the faction when
       no colony references it.
    2. Queue an ``abandon`` job for every colony that has a faction and a
       population of 0; the job sets the colony's faction to null.
    3. When ``utils.dist_skewedLeft(config['factions'])`` exceeds the current
       faction count, pick one candidate planet and, if it is unoccupied or
       at least 2 larger than what already sits on it, queue the creation of
       a new faction and its first colony.

    Args:
        tick: Current tick number; also used to build the new faction's id.
        config: Settings dict. Keys read here: ``cleanFactions``,
            ``factions``, ``optimalDistance``, ``population_industry``,
            ``pacific_militaristic``, ``defence_attack``, ``growth_expand``.
        q: Not used by this function.

    Returns:
        The string ``'done'``.
    """
    # Disband 'empty' factions with no colonies
    def disband(_id):
        owned = cp.query(payload="SELECT *\
            FROM massive\
            WHERE object == 'colony' && faction == '" + _id + "'\
            LIMIT 0, 0")
        if int(owned['hits']) != 0:
            return 'factions: keep faction ' + _id
        # The result is still indexed so that a response without results
        # raises here, exactly as before.
        cp.query(payload="DELETE massive['" + _id + "']")['results'][0]['_id']
        return 'factions: disband faction ' + _id

    if tick%config['cleanFactions'] == 0:
        factions = cp.query(payload="\
            SELECT _id\
            FROM massive\
            WHERE object == 'faction'\
            LIMIT 0,9999")
        if 'results' in factions:
            for faction in factions['results']:
                utils.queue(disband, faction['_id'])

    # Abandon depopulated colonies
    def abandon(_id):
        updated_id = cp.query(payload="\
            UPDATE massive['" + _id + "']\
            SET faction = null")['results'][0]['_id']
        return 'factions: abandon colony ' + updated_id

    colonies = cp.query(payload="\
        SELECT _id, faction\
        FROM massive\
        WHERE object == 'colony' && population == 0 && faction")
    if 'results' in colonies:
        for colony in colonies['results']:
            utils.queue(abandon,colony['_id'])

    # Spawn new factions with initial colony
    count = int(cp.query(payload="SELECT * FROM massive WHERE object == 'faction' LIMIT 0, 0")['hits'])
    want = int(utils.dist_skewedLeft(config['factions']))
    if want > count:
        optimal = config['optimalDistance'][0]
        spread = config['optimalDistance'][1]
        found = cp.query(payload="\
            SELECT _id, size, habitability, richness, materials\
            FROM massive\
            WHERE\
                object == 'planet' &&\
                type != 'Gas' &&\
                system_coords.radius < "+str(int(optimal*spread))+" &&\
                system_coords.radius > "+str(int(optimal/spread))+"\
            ORDER BY Math.abs("+str(optimal)+" - system_coords.radius) * Math.random()\
            LIMIT 0, 1")
        # The other queries in this function guard on 'results'; do the same
        # here so an empty universe does not crash the tick.
        if 'results' not in found or not found['results']:
            return 'done'
        planet = found['results'][0]

        occupied = cp.query(payload="\
            SELECT SUM(population+industry) as size\
            FROM massive\
            WHERE object == 'colony' && anchor == '"+planet['_id']+"'\
            GROUP BY anchor")
        if int(occupied['hits']) == 0 or (planet['size'] >= occupied['results'][0]['size'] + 2):
            faction_id = 'faction' + str(tick)
            # Raw string: "\w" in a normal literal is an invalid escape
            # sequence. Evaluated once, before anything is queued, so a
            # non-matching id cannot leave a faction without a colony.
            local_goods = re.search(r"(\w*)p", planet['_id']).group(1)

            utils.queue(cp.put, payload={
                'object': 'faction',
                'pref': {  # 0 prefers first, 1 prefers last, 0.5 is indifferent.
                    'population_industry': utils.dist_flat(config['population_industry']),
                    'pacific_militaristic': utils.dist_flat(config['pacific_militaristic']),
                    'defence_attack': utils.dist_flat(config['defence_attack']),
                    'growth_expand': utils.dist_flat(config['growth_expand']),
                }},
                params='[' + faction_id + ']',
                msg='factions: spawn faction '+str(tick)
            )
            utils.queue(cp.put, payload={
                'object': 'colony',
                'faction': faction_id,
                'anchor': planet['_id'],
                'goods': local_goods,
                'untilJoins': {
                    'habitability': planet['habitability'],
                    'richness': planet['richness'],
                    'materials': planet['materials'],
                },
                'population': 1,
                'industry': 1,
                'storage': {
                    'goods': {
                        local_goods: 100,
                    },
                    'solids': 100,  # Upkeep for industry
                    'metals': 0,  # For structure of ships
                    'isotopes': 0,  # For guns of ships
                    'ammo': 0  # For guns to shoot
                }},
                msg='factions: spawn colony'
            )

    return 'done'
Exemplo n.º 5
0
def main(tick, config, q):
    """Queue the creation of new systems when fewer exist than wanted.

    The current number of ``system`` objects is compared with
    ``utils.dist_skewedLeft(config['systems'])``. When more are wanted, up to
    ``config['systemsVolatility']`` systems are generated, each with its
    stars, planets and moons, and every body is queued as a ``cp.put`` via
    ``utils.queue``. Ids are built as ``[t<tick>y<system>s<star>p<planet>m<moon>]``.

    Stars and planets are stored at the same coordinates that are handed to
    ``rndCoords`` as the third argument for their planets and moons.

    Args:
        tick: Current tick number; becomes part of every generated id.
        config: Settings dict with the distribution, distance, type and
            weight entries read below.
        q: Not used by this function.

    Returns:
        The string ``"done"``.
    """
    # How many systems are there?
    count = int(cp.query(payload="SELECT * FROM massive WHERE object == 'system' LIMIT 0, 0")["hits"])

    # How many do we want?
    want = int(utils.dist_skewedLeft(config['systems']))

    if want > count:
        # Generate up to X more systems per tick
        body_types = ['Gas', 'Ice', 'Rock', 'Iron', 'Mix']
        pre = '[t' + str(tick)
        for system in range(0, min(config['systemsVolatility'], want - count)):
            urly = 'y' + str(system)
            for star in range(0, int(utils.dist_skewedLeft(config['stars']))):
                urls = urly + 's' + str(star)
                coordss = rndCoords(config['starDistance'], config['zFlatness'])
                for planet in range(0, int(utils.dist_skewedLeft(config['planets']))):
                    urlp = urls + 'p' + str(planet)
                    coordsp = rndCoords(config['planetDistance'], config['zFlatness'], coordss)
                    for moon in range(0, int(utils.dist_skewedLeft(config['moons']))):
                        moon_id = pre + urlp + 'm' + str(moon) + "]"
                        utils.queue(cp.put, payload={
                                "object": "moon",
                                "system_coords": rndCoords(config['moonDistance'], config['zFlatness'], coordsp),
                                'type': body_types[utils.choose_weighted(config['moonType'])],
                                'size': round(utils.dist_skewedLeft(config['moonSize']), 0),
                                'habitability': round(utils.dist_skewedLeft(config['moonHabitability']), 1),
                                'richness': round(utils.dist_skewedLeft(config['moonRichness']), 1),
                                'materials': rndMaterials(config['moonWeightSolidsOther'], config['moonWeightMetalsIsotopes']),
                            },
                            params=moon_id,
                            msg='      Moon ' + moon_id
                        )
                    planet_id = pre + urlp + "]"
                    utils.queue(cp.put, payload={
                            "object": "planet",
                            # Store the coordinates the moons were generated
                            # around, not a second unrelated random draw.
                            "system_coords": coordsp,
                            'type': body_types[utils.choose_weighted(config['planetType'])],
                            'size': round(utils.dist_skewedLeft(config['planetSize']), 0),
                            'habitability': round(utils.dist_skewedLeft(config['planetHabitability']), 1),
                            'richness': round(utils.dist_skewedLeft(config['planetRichness']), 1),
                            'materials': rndMaterials(config['planetWeightSolidsOther'], config['planetWeightMetalsIsotopes']),
                        },
                        params=planet_id,
                        msg='    Planet ' + planet_id
                    )
                star_id = pre + urls + "]"
                utils.queue(cp.put, payload={
                        "id": star,
                        "object": "star",
                        # Same reasoning as for planets: the planets were
                        # generated around coordss.
                        "system_coords": coordss,
                    },
                    params=star_id,
                    msg='  Star ' + star_id
                )
            system_id = pre + urly + "]"
            utils.queue(cp.put, payload={
                "object": "system",
                "universe_coords": {
                    "x": random.randrange(config['x'][0],config['x'][1]),
                    "y": random.randrange(config['y'][0],config['y'][1]),
                    "z": random.randrange(config['z'][0],config['z'][1]),
                },},
                params=system_id,
                msg='System ' + system_id
            )
    # TODO destroy systems when want < count

    return "done"
Exemplo n.º 6
0
def select_feature(input_file_name,output_file_name,model_dict):
    """Extract per-sentence features from a tagged corpus and write them out.

    The input is read line by line. A line whose first token is ``<BOS>``
    starts a sentence and one whose first token is ``<EOS>`` ends it; every
    other line is encoded with ``utils.encode_feature`` using a window of
    previous lines (``history_queue``) and pre-read following lines
    (``feature_queue``).

    Three files are written:

    * ``output/total_feature_with_tags.txt``: one ``feature<TAB>index`` line
      for every feature whose total count is at least ``threshold``.
    * ``output_file_name``: per sentence a ``<FEA>`` line, the tab-separated
      tag-free features, the ``feature<TAB>count`` lines that pass
      ``threshold``, and a closing ``<SEP>`` line.
    * ``output/feature_dict.txt``: written by ``id_factory.write_file``.

    ``threshold`` is a module-level name that is not defined in this
    function.

    Args:
        input_file_name: Path of the UTF-8 corpus to read.
        output_file_name: Path of the UTF-8 per-sentence feature file.
        model_dict: Configuration mapping; only ``model_dict["U"]`` is read.
    """
    # Keep five steps of history and five steps of look-ahead by default.
    history_size = 5
    future_size = history_size
    history_queue = queue(history_size)
    feature_queue = queue(future_size)
    # Feature templates taken from the configuration.
    unigram = model_dict["U"]
    # Read everything up front and close the input file right away.
    with codecs.open(input_file_name,"r","utf-8") as input_file:
        input_data = input_file.readlines()
    # Used only for progress reporting.
    progress=0
    lens = len(input_data)
    # Index of the next line to pre-read; starts at the first token line
    # (the line after the first <BOS>).
    idx = 1
    # Encodes and decodes feature ids.
    id_factory = feature_id_factory()
    # Feature counts of the sentence currently being processed.
    line_feature_dict = {}

    total_line_feature_dict = []
    total_feature_without_tags = []
    # Tag-free features of the sentence currently being processed.
    feature_without_tags = []
    # Feature counts over the whole corpus.
    total_feature_dict = {}
    enter_another_line = False
    for line in input_data:
        words = line.strip().split()
        head = words[0]
        if head == '<BOS>':
            idx+=1
            # Pre-read the next few lines into feature_queue.
            enter_another_line = False
            history_queue.clear()
            feature_queue.clear()
            for _ in range(future_size-1):
                if idx >= lens:
                    break
                ahead = input_data[idx]
                # Reaching this point means a line has been pre-read.
                idx += 1
                if ahead.strip().split()[0] == '<EOS>':
                    enter_another_line = True
                    break
                feature_queue.push(ahead)
            continue

        if head == "<EOS>":
            feature_queue.clear()
            history_queue.clear()
            if not feature_without_tags:
                continue
            total_line_feature_dict.append(line_feature_dict)
            total_feature_without_tags.append(feature_without_tags)
            line_feature_dict = {}
            feature_without_tags = []

            # Here idx points at the next <BOS> and enter_another_line is
            # True, so advance idx to skip that <BOS>.
            idx+=1
            continue

        # Pre-read one more line of look-ahead.
        if not enter_another_line and idx < lens:
            ahead = input_data[idx]
            ahead_head = ahead.strip().split()[0]
            idx+=1
            if ahead_head != '<EOS>' and ahead_head != '<BOS>':
                feature_queue.push(ahead)
            if ahead_head == '<EOS>':
                enter_another_line = True
        progress+=1

        # Progress output; adjust the modulus if it reports too often or
        # too rarely.
        if progress%100000 == 0:
            print("extract feature complete %d%%" % (progress*100//lens+1))
        feature_list,total_feature_without_tag =  utils.encode_feature(words,unigram,history_queue,feature_queue,id_factory)
        # Important: drop the entry for the line that was just encoded.
        feature_queue.pop()
        for cur_feature in feature_list:
            line_feature_dict[cur_feature] = line_feature_dict.get(cur_feature, 0) + 1
            total_feature_dict[cur_feature] = total_feature_dict.get(cur_feature, 0) + 1
        feature_without_tags.append(total_feature_without_tag)
        history_queue.push(words)

    print('writing features into disk ...')
    total_feature_dict_filename = 'output/total_feature_with_tags.txt'
    with codecs.open(total_feature_dict_filename,"w","utf-8") as total_feature_dict_file:
        no = 0
        for key in total_feature_dict:
            if total_feature_dict[key] >= threshold:
                total_feature_dict_file.write(key+"\t"+str(no)+"\n")
                no+=1
    with codecs.open(output_file_name,"w",'utf-8') as output_file:
        # Both lists are appended to together, so they have equal length.
        for sentence_counts, sentence_plain in zip(total_line_feature_dict, total_feature_without_tags):
            output_file.write("<FEA>\n")
            for plain_feature in sentence_plain:
                output_file.write(str(plain_feature)+"\t")
            output_file.write("\n")
            for key in sentence_counts:
                if total_feature_dict[key] >= threshold:
                    output_file.write(str(key)+"\t"+str(sentence_counts[key])+"\n")
            output_file.write("<SEP>\n")
    id_factory.write_file("output/feature_dict.txt")
    print('extract feature complete!!')
Exemplo n.º 7
0
def main(tick, config, q):
    """Queue a production or an upkeep job for a random subset of colonies.

    Colonies that have a faction and a population above 0 are selected with
    probability ``1/config['batchEconomy']`` by the query itself. Each
    selected colony gets either ``produce`` or ``upkeep`` queued through
    ``utils.queue``, with equal chance.

    Args:
        tick: Not used by this function.
        config: Settings dict with ``batchEconomy`` and the ``pop*`` / ``ind*``
            entries read below.
        q: Not used by this function.

    Returns:
        The string ``'done'``.
    """
    # On average, per tick:
    # Colony produce (10 * (habitability-population/10)) local goods per population
    # Colony produce (10 * richness * weight) every material per industry
    # Colony upkeeps 1 unique exotic good or 4 local goods per population
    # Colony upkeeps 1 unique exotic good or 4 solids per industry
    def produce(_id):
        batch = 2 * config['batchEconomy']
        colony = cp.query(payload="\
            SELECT goods, storage, industry, population, untilJoins\
            FROM massive\
            WHERE _id == '"+_id+"'")['results'][0]
        storage = colony['storage']
        local = colony['goods']
        site = colony['untilJoins']

        # produce_goods
        goodsPerPop = float(site['habitability']) - float(colony['population']) * config['popDAR']
        goodsTotal = batch * float(colony['population']) * goodsPerPop
        storage['goods'][local] += int(goodsTotal)

        # produce_materials
        materialsPerInd = float(site['richness']) - float(colony['industry']) * config['indDAR']
        materialsTotal = batch * float(colony['industry']) * materialsPerInd
        storage['solids'] += int(materialsTotal * float(site['materials'][0]))
        storage['metals'] += int(materialsTotal * float(site['materials'][1]))
        storage['isotopes'] += int(materialsTotal * float(site['materials'][2]))

        r = cp.query(payload="\
            UPDATE massive['" + _id + "']\
            SET storage.goods['"+local+"'] = "+str(storage['goods'][local])+",\
                storage['solids'] = "+str(storage['solids'])+",\
                storage['metals'] = "+str(storage['metals'])+",\
                storage['isotopes'] = "+str(storage['isotopes'])+"\
        ")['results'][0]['_id']
        return 'economy: produce at ' + r

    def consume_goods(goods, units, limit, batch):
        """Pay for up to ``units`` units with one ``batch`` of a stored good each.

        A good is used when at least ``batch`` of it is stored and
        ``random.random() >= limit``. Returns the number of units left unpaid.
        """
        for key in goods:
            if units == 0:
                break
            if goods[key] >= batch and random.random() >= limit:
                goods[key] -= batch
                units -= 1
        return units

    def upkeep(_id):
        batch = 2 * config['batchEconomy']
        colony = cp.query(payload="\
            SELECT goods, storage, industry, population, untilJoins\
            FROM massive\
            WHERE _id == '"+_id+"'")['results'][0]
        storage = colony['storage']
        goods = storage['goods']
        local = colony['goods']

        # upkeep_population: the unpaid remainder is charged in local goods;
        # if that cannot be paid, population drops by one with a refund.
        unpaid = consume_goods(goods, colony['population'], config['popForeignGoods'], batch)
        local_cost = unpaid * config['popLocalPenalty'] * batch
        if goods[local] >= local_cost:
            goods[local] -= local_cost
        else:
            colony['population'] -= 1
            goods[local] += config['popDowngradeRefund'] * config['popUpgradeGoods']
            storage['solids'] += config['popDowngradeRefund'] * config['popUpgradeSolids']

        # upkeep_industry: the unpaid remainder is charged in solids;
        # if that cannot be paid, industry drops by one with a refund.
        unpaid = consume_goods(goods, colony['industry'], config['indForeignGoods'], batch)
        solids_cost = unpaid * config['indLocalPenalty'] * batch
        if storage['solids'] >= solids_cost:
            storage['solids'] -= solids_cost
        else:
            colony['industry'] -= 1
            goods[local] += config['indDowngradeRefund'] * config['indUpgradeGoods']
            storage['solids'] += config['indDowngradeRefund'] * config['indUpgradeSolids']

        # The last assignment used to be spelled "industy", so a reduced
        # industry level was written to the wrong field.
        r = cp.query(payload="\
            UPDATE massive['" + _id + "']\
            SET storage['goods'] = "+json.dumps(goods)+",\
                storage['solids'] = "+str(int(storage['solids']))+",\
                storage['metals'] = "+str(int(storage['metals']))+",\
                storage['isotopes'] = "+str(int(storage['isotopes']))+",\
                population = "+str(int(colony['population']))+",\
                industry = "+str(int(colony['industry']))+"\
        ")['results'][0]['_id']
        return 'economy: upkeep at ' + r

    # Which colonies get production or upkeep this tick?
    r = cp.query(payload="\
        SELECT _id FROM massive\
        WHERE object == 'colony' && faction && population > 0 && Math.random()<"+str(1/config['batchEconomy'])+"\
        LIMIT 0, 999999")

    if int(r['hits']) > 0:
        # Iterate through colonies
        for lucky in r['results']:
            job = produce if random.random() < 0.5 else upkeep
            utils.queue(job, lucky['_id'])

    return 'done'
Exemplo n.º 8
0
    def train_controller(self, old_model, new_model, optimizer, device,
                         train_loader, valid_loader, epoch, momentum,
                         entropy_weight, child_retrain_epoch,
                         child_retrain_interval):
        """Train ``new_model`` from children sampled with ``old_model``.

        For each of ``epoch`` controller epochs, ``self.num_of_children``
        architectures are sampled by calling ``old_model()`` without input.
        Each sample is trained with ``self.train_child`` and evaluated with
        ``self.test_child``; the reward (validation accuracy plus
        ``entropy_weight * sampled_entropies`` minus the mean of
        ``prev_runs``) is stored with ``old_model.memory.add_rewards``.

        After sampling, ``new_model`` is updated ``self.K_epochs`` times: the
        stored transitions are re-evaluated with ``new_model.evaluate``, the
        mean of ``exp(new_logprobs - old_logprobs)`` is multiplied by the mean
        reward, and the loss is the negated minimum of that product and its
        variant with the ratio clamped to ``[0.8, 1.2]``. Finally the memory
        is cleaned and ``new_model``'s state dict is loaded into ``old_model``.

        Args:
            old_model: Controller used for sampling; must expose
                ``sampled_architecture``, ``sampled_entropies``,
                ``sampled_logprobs`` and a ``memory`` object.
            new_model: Controller that is optimized; must provide
                ``evaluate(transition)`` returning four values.
            optimizer: Optimizer stepped once per inner update.
            device: Device the child network is moved to.
            train_loader: Passed to ``self.train_child`` and ``self.retrain``.
            valid_loader: Passed to ``self.test_child`` and ``self.retrain``.
            epoch: Number of controller epochs to run.
            momentum: Passed to ``SharedEnasChild`` when ``self.isShared`` is
                false.
            entropy_weight: Multiplier for the sampled entropies added to the
                reward.
            child_retrain_epoch: Passed to ``self.retrain``.
            child_retrain_interval: The best child of an epoch is retrained
                and ``self.save`` is called when
                ``epoch_idx % child_retrain_interval == 0``.

        Returns:
            ``prev_runs`` as produced by the last ``queue(...)`` call, or the
            initial zero tensor of length 5 when ``epoch`` is 0.
        """

        new_model.train()

        # Global counter over all sampled children; used as tensorboard step.
        step = 0

        prev_runs = torch.zeros([5])  # to store the val_acc of prev epochs

        for epoch_idx in range(epoch):
            # Default value that is logged when the update loop below does not run.
            loss = torch.FloatTensor([0])

            # Per-child validation accuracy and config, to pick the best child.
            epoch_valacc = torch.zeros(self.num_of_children)
            epoch_childs = []

            for child_idx in range(self.num_of_children):
                # images, labels.cuda()

                step += 1

                # presumably this call also records the transition in
                # old_model.memory; verify against the model class
                old_model()  # forward pass without input
                sampled_architecture = old_model.sampled_architecture
                sampled_entropies = old_model.sampled_entropies.detach()
                sampled_logprobs = old_model.sampled_logprobs

                # get the acc of a single child

                # make child
                conf = self.make_enas_config(sampled_architecture)
                epoch_childs.append(conf)

                print("CONF:", conf)

                if self.isShared:
                    # Reuse the single shared child network.
                    child = self.child.to(device)
                else:
                    # Build a new child network for this config.
                    child = SharedEnasChild(
                        conf,
                        self.num_layers,
                        self.learning_rate_child,
                        momentum,
                        num_classes=self.num_classes,
                        out_filters=self.out_filters,
                        input_shape=self.input_shape,
                        input_channels=self.input_channels).to(device)

                # self.logger.info("train_controller, epoch/child : ", epoch_idx, child_idx, " child : ", conf) # logging error

                # Train child
                self.train_child(child, conf, device, train_loader,
                                 self.epoch_child, epoch_idx, child_idx)

                # Test child
                validation_accuracy, validation_loss = self.test_child(
                    child, conf, device, valid_loader)
                epoch_valacc[child_idx] = validation_accuracy
                self.bestchilds.add(conf, validation_accuracy)

                reward = torch.tensor(validation_accuracy).detach()
                # In-place add of the entropy bonus to the reward.
                reward += sampled_entropies * entropy_weight

                # calculate advantage with baseline (moving avg)

                baseline = prev_runs.mean(
                )  # subtract baseline to reduce variance in rewards
                reward = reward - baseline

                old_model.memory.add_rewards(reward)

                # logging to tensorboard
                self.writer.add_scalar("reward", reward, global_step=step)
                self.writer.add_scalar("valid_acc",
                                       validation_accuracy,
                                       global_step=step)
                self.writer.add_scalar("valid_loss",
                                       validation_loss,
                                       global_step=step)
                self.writer.add_scalar("sampled_entropies",
                                       sampled_entropies,
                                       global_step=step)
                self.writer.add_scalar("sampled_logprobs",
                                       sampled_logprobs,
                                       global_step=step)

            # Child with the highest validation accuracy in this epoch.
            best_child_idx = torch.argmax(epoch_valacc)
            best_child_conf = epoch_childs[best_child_idx]

            message = " best valacc" + str(epoch_valacc[best_child_idx].item())\
                      + ' - config: ' + str(best_child_conf)

            self.writer.add_text("best child", message, global_step=epoch_idx)

            if epoch_idx % child_retrain_interval == 0:

                # NOTE(review): retrained_loss is not used afterwards.
                retrained_valacc, retrained_loss = self.retrain(
                    best_child_conf, device, train_loader, valid_loader,
                    child_retrain_epoch, epoch_idx)

                print("current best childs: ", self.bestchilds.bestchilds)
                self.writer.add_scalar("retrainerd child valacc",
                                       retrained_valacc, epoch_idx)

                self.save(epoch_idx)

            # Log-probabilities and rewards collected while sampling this epoch.
            old_branch_logprobs, old_skip_logprobs = old_model.memory.get_logprobs(
            )
            old_rewards = old_model.memory.rewards

            # Rebuilt as plain float tensors so they act as constants in the
            # ratio below.
            old_skip_logprobs = torch.tensor(
                old_skip_logprobs,
                dtype=torch.float).detach()  # NO GRADIENT Info
            old_branch_logprobs = torch.tensor(
                old_branch_logprobs,
                dtype=torch.float).detach()  # NO GRADIENT Info

            # Optimize policy for K epochs:

            for _ in range(self.K_epochs):

                #print("inside ppo update loop")

                new_branch_logprobs = torch.zeros_like(old_branch_logprobs)
                new_skip_logprobs = torch.zeros_like(old_skip_logprobs)
                # NOTE(review): branch_entropies is shaped like the skip
                # log-probs, not the branch ones; confirm this is intended.
                # Neither entropy tensor is read again after being filled.
                branch_entropies = torch.zeros_like(old_skip_logprobs)
                skip_entropies = torch.zeros_like(old_skip_logprobs)

                #print("branchhsape", new_branch_logprobs.shape)
                #print("skiphape", new_skip_logprobs.shape)
                #print("branch entropyshape", branch_entropies.shape)
                #print("skip entropyshape", skip_entropies.shape)

                # Re-evaluate every stored transition with the model being trained.
                for tx_idx in range(len(old_model.memory.rewards)):
                    #print(new_model.evaluate(old_model.memory.transitions[tx_idx]))
                    new_branch_logprobs[tx_idx], new_skip_logprobs[
                        tx_idx], branch_entropies[tx_idx], skip_entropies[
                            tx_idx] = new_model.evaluate(
                                old_model.memory.transitions[tx_idx])

                # Finding the ratio (pi_theta / pi_theta__old):
                branch_ratios = torch.exp(new_branch_logprobs -
                                          old_branch_logprobs)
                skip_ratios = torch.exp(new_skip_logprobs - old_skip_logprobs)
                ratios = torch.cat((branch_ratios, skip_ratios))

                # Both the rewards and the ratios are reduced to a single mean value.
                advantages = torch.Tensor(old_rewards).mean()
                ratios = ratios.mean()
                #print( ratios, advantages)

                surr1 = ratios * advantages
                #print(surr1)

                # Same product with the ratio clamped to [0.8, 1.2].
                surr2 = torch.clamp(ratios, 1 - 0.2, 1 + 0.2) * advantages
                loss = -torch.min(
                    surr1, surr2
                )  # + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy
                #print(loss)

                # take gradient step
                optimizer.zero_grad()
                loss.mean().backward(retain_graph=True)
                optimizer.step()

            # Copy new weights into old policy:
            old_model.memory.clean()
            old_model.load_state_dict(new_model.state_dict())

            # self.writer.add_histogram("sampled_branches", model.sampled_architecture, global_step=epoch_idx)
            # self.writer.add_histogram("sampled_connections", model.sampled_architecture[1], global_step=epoch_idx)

            # Logs the loss of the last inner update of this epoch.
            self.writer.add_scalar("epoch_loss",
                                   loss.mean().item(),
                                   global_step=epoch_idx)
            self.writer.add_scalar("epoch mean validation acc.",
                                   epoch_valacc.mean(),
                                   global_step=epoch_idx)

            # self.writer.add_graph(child) #ERROR:  TracedModules don't support parameter sharing between modules

            # presumably pushes this epoch's mean accuracy into the baseline
            # window; queue() is defined elsewhere, confirm its semantics
            prev_runs = queue(prev_runs, epoch_valacc.mean())

        return prev_runs