def __init__(
        self,
        config: Config,
):
    """Build the prooftrace LanguageModel trainer state from ``config``.

    Reads training hyperparameters, selects the torch device, optionally
    opens a TensorBoard writer (rank 0 only), and instantiates the four
    sub-modules: embedder ``E``, trunk ``H``, policy head ``PH`` and
    value head ``VH``.
    """
    self._config = config

    # Gradient accumulation factor: the configured batch is split into
    # this many smaller forward/backward passes per optimizer step.
    self._accumulation_step_count = \
        config.get('prooftrace_lm_accumulation_step_count')
    self._learning_rate = \
        config.get('prooftrace_lm_learning_rate')
    # Weight of the value loss relative to the action/pointer losses.
    self._value_coeff = config.get('prooftrace_lm_value_coeff')

    self._device = torch.device(config.get('device'))

    self._save_dir = config.get('prooftrace_save_dir')
    self._load_dir = config.get('prooftrace_load_dir')

    # Only the rank-0 worker writes TensorBoard events, so distributed
    # runs don't produce duplicate/conflicting logs.
    self._tb_writer = None
    if self._config.get('tensorboard_log_dir'):
        if self._config.get('distributed_rank') == 0:
            self._tb_writer = SummaryWriter(
                self._config.get('tensorboard_log_dir'),
            )

    # "Inner" models are the raw modules kept for checkpointing; the
    # _model_* aliases below may later be rebound to wrapped versions
    # (e.g. DistributedDataParallel) without losing the originals.
    self._inner_model_E = E(self._config).to(self._device)
    self._inner_model_H = H(self._config).to(self._device)
    self._inner_model_PH = PH(self._config).to(self._device)
    self._inner_model_VH = VH(self._config).to(self._device)

    Log.out(
        "Initializing prooftrace LanguageModel", {
            'parameter_count_E': self._inner_model_E.parameters_count(),
            'parameter_count_H': self._inner_model_H.parameters_count(),
            'parameter_count_PH': self._inner_model_PH.parameters_count(),
            'parameter_count_VH': self._inner_model_VH.parameters_count(),
        },
    )

    # Default: unwrapped aliases (single-process execution).
    self._model_E = self._inner_model_E
    self._model_H = self._inner_model_H
    self._model_PH = self._inner_model_PH
    self._model_VH = self._inner_model_VH

    # NLL over log-probabilities for action/pointer targets; MSE for the
    # value head.
    self._nll_loss = nn.NLLLoss()
    self._mse_loss = nn.MSELoss()

    self._train_batch = 0
def __init__(
        self,
        config: Config,
):
    """Construct the inference-side model set and load the tokenizer.

    The pickled trace tokenizer is read from the dataset directory; the
    embedder, the two trunks (HV/HP) and the two heads (PH/VH) are
    instantiated on the configured device.
    """
    self._config = config
    self._device = torch.device(config.get('device'))
    self._load_dir = config.get('prooftrace_load_dir')

    # Locate and unpickle the gzip'ed tokenizer shipped with the dataset.
    tokenizer_path = os.path.join(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
        'traces.tokenizer',
    )
    with gzip.open(tokenizer_path, 'rb') as f:
        self._tokenizer = pickle.load(f)

    device = self._device
    # One shared embedder; separate value (HV) and policy (HP) trunks
    # feeding the policy (PH) and value (VH) heads.
    self._model_E = E(self._config).to(device)
    self._model_HV = H(self._config).to(device)
    self._model_HP = H(self._config).to(device)
    self._model_PH = PH(self._config).to(device)
    self._model_VH = VH(self._config).to(device)
def __init__(
        self,
        config: Config,
        modules: typing.Dict[str, torch.nn.Module] = None,
):
    """Hold the (pE, pT, pH) module set, building defaults when absent.

    When ``modules`` is supplied it must already contain the 'pE', 'pT'
    and 'pH' entries; otherwise a fresh set is instantiated on the
    configured device.
    """
    self._config = config
    self._device = torch.device(config.get('device'))

    if modules is None:
        # No shared modules provided: build our own on the target device.
        modules = {
            'pE': E(self._config).to(self._device),
            'pT': T(self._config).to(self._device),
            'pH': PH(self._config).to(self._device),
        }
    else:
        # Reuse the caller's modules; the full set is mandatory.
        assert 'pE' in modules
        assert 'pT' in modules
        assert 'pH' in modules
    self._modules = modules
def __init__(
        self,
        config: Config,
        train_dataset: ProofTraceLMDataset,
):
    """Initialize an IOTA ACK worker for language-model training.

    Builds the (E, T, PH, VH) module set, registers it with the IOTA
    sync directory, and prepares a shuffled training DataLoader over
    ``train_dataset``.
    """
    self._config = config

    # Relative weights of the action and value losses.
    self._action_coeff = config.get('prooftrace_lm_action_coeff')
    self._value_coeff = config.get('prooftrace_lm_value_coeff')

    self._device = torch.device(config.get('device'))

    self._modules = {
        'E': E(self._config).to(self._device),
        'T': T(self._config).to(self._device),
        'PH': PH(self._config).to(self._device),
        'VH': VH(self._config).to(self._device),
    }

    # ACK side of the IOTA protocol, paired with the SYN process that
    # owns the optimizer (see the SYN initializer elsewhere in the
    # project) through the shared sync directory.
    self._ack = IOTAAck(
        config.get('prooftrace_lm_iota_sync_dir'),
        self._modules,
    )

    # NLL over log-probabilities for action/pointer targets; MSE for the
    # value head.
    self._nll_loss = nn.NLLLoss()
    self._mse_loss = nn.MSELoss()

    self._train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=self._config.get('prooftrace_lm_batch_size'),
        shuffle=True,
        collate_fn=lm_collate,
    )

    Log.out('ACK initialization', {
        "batch_size": self._config.get('prooftrace_lm_batch_size'),
    })

    self._train_batch = 0
def __init__(
        self,
        config: Config,
):
    """Set up the test ('TST') search runner.

    Instantiates the module set, locates the held-out test traces, loads
    the trace tokenizer, and registers with the search IOTA sync
    directory under the 'test' role.
    """
    self._config = config
    self._device = torch.device(config.get('device'))

    self._modules = {
        'E': E(self._config).to(self._device),
        'T': T(self._config).to(self._device),
        'PH': PH(self._config).to(self._device),
        'VH': VH(self._config).to(self._device),
    }

    # Both the test-trace directory and the tokenizer live under the
    # same dataset root.
    dataset_root = os.path.join(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
    )
    self._dataset_dir = os.path.join(dataset_root, 'test_traces')

    with gzip.open(
            os.path.join(dataset_root, 'traces.tokenizer'), 'rb') as f:
        self._tokenizer = pickle.load(f)

    self._tst = IOTARun(
        config.get('prooftrace_search_iota_sync_dir'),
        'test',
        self._modules,
    )

    self._test_gamma_size = config.get('prooftrace_search_test_gamma_size')

    Log.out('TST initialization', {})
def __init__(
        self,
        config: Config,
):
    """Initialize the SYN (parameter-server) side of IOTA LM training.

    Owns the canonical module weights and the Adam optimizer, publishes
    them through the IOTA sync directory, and broadcasts the config to
    ACK workers.
    """
    self._config = config

    self._learning_rate = config.get('prooftrace_lm_learning_rate')
    # presumably the minimum number of ACK updates to aggregate before
    # applying an optimizer step — TODO confirm against the run loop.
    self._min_update_count = \
        config.get('prooftrace_lm_iota_min_update_count')

    self._device = torch.device(config.get('device'))

    self._save_dir = config.get('prooftrace_save_dir')
    self._load_dir = config.get('prooftrace_load_dir')

    self._epoch = 0

    self._tb_writer = None
    if self._config.get('tensorboard_log_dir'):
        self._tb_writer = SummaryWriter(
            self._config.get('tensorboard_log_dir'),
        )

    self._modules = {
        'E': E(self._config).to(self._device),
        'T': T(self._config).to(self._device),
        'PH': PH(self._config).to(self._device),
        'VH': VH(self._config).to(self._device),
    }

    Log.out(
        "SYN Initializing", {
            'parameter_count_E': self._modules['E'].parameters_count(),
            'parameter_count_T': self._modules['T'].parameters_count(),
            'parameter_count_PH': self._modules['PH'].parameters_count(),
            'parameter_count_VH': self._modules['VH'].parameters_count(),
        },
    )

    self._syn = IOTASyn(
        config.get('prooftrace_lm_iota_sync_dir'),
        self._modules,
    )

    # One param group per module so learning rates could be tuned
    # independently later.
    self._optimizer = optim.Adam(
        [
            {'params': self._modules['E'].parameters()},
            {'params': self._modules['T'].parameters()},
            {'params': self._modules['PH'].parameters()},
            {'params': self._modules['VH'].parameters()},
        ],
        lr=self._learning_rate,
    )

    # Push the (possibly overridden) config to all ACK workers.
    self._syn.broadcast({'config': self._config})
class Model:
    """Inference-only wrapper around the prooftrace modules.

    Loads rank-0 checkpoints for the embedder (E), the value/policy
    trunks (HV/HP) and the policy/value heads (PH/VH), and runs batched
    no-grad inference over action traces.
    """

    def __init__(
            self,
            config: Config,
    ):
        """Instantiate the sub-modules and load the trace tokenizer."""
        self._config = config
        self._device = torch.device(config.get('device'))
        self._load_dir = config.get('prooftrace_load_dir')

        with gzip.open(
                os.path.join(
                    os.path.expanduser(config.get('prooftrace_dataset_dir')),
                    config.get('prooftrace_dataset_size'),
                    'traces.tokenizer',
                ), 'rb') as f:
            self._tokenizer = pickle.load(f)

        self._model_E = E(self._config).to(self._device)
        self._model_HV = H(self._config).to(self._device)
        self._model_HP = H(self._config).to(self._device)
        self._model_PH = PH(self._config).to(self._device)
        self._model_VH = VH(self._config).to(self._device)

    def load(
            self,
    ):
        """Load checkpoints (when present) and switch to eval mode.

        Returns ``self`` so the call can be chained after construction.
        NOTE(review): the ``_0`` suffix hard-codes the rank-0 checkpoint
        — presumably intentional for single-process inference; confirm
        against the distributed saving side which formats the rank in.
        """
        if self._load_dir:
            # The E checkpoint acts as the sentinel for a complete set.
            if os.path.isfile(self._load_dir + "/model_E_0.pt"):
                Log.out("Loading prooftrace", {
                    'load_dir': self._load_dir,
                })
                self._model_E.load_state_dict(
                    torch.load(
                        self._load_dir + "/model_E_0.pt",
                        map_location=self._device,
                    ),
                )
                self._model_HV.load_state_dict(
                    torch.load(
                        self._load_dir + "/model_HV_0.pt",
                        map_location=self._device,
                    ),
                )
                self._model_HP.load_state_dict(
                    torch.load(
                        self._load_dir + "/model_HP_0.pt",
                        map_location=self._device,
                    ),
                )
                self._model_PH.load_state_dict(
                    torch.load(
                        self._load_dir + "/model_PH_0.pt",
                        map_location=self._device,
                    ),
                )
                self._model_VH.load_state_dict(
                    torch.load(
                        self._load_dir + "/model_VH_0.pt",
                        map_location=self._device,
                    ),
                )
                # Inference only: disable dropout/batch-norm updates.
                self._model_E.eval()
                self._model_HV.eval()
                self._model_HP.eval()
                self._model_PH.eval()
                self._model_VH.eval()

        return self

    def infer(
            self,
            trc: typing.List[typing.List[Action]],
            idx: typing.List[typing.List[int]],
    ) -> typing.Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Run one no-grad forward pass over a batch of traces.

        For each sequence ``i`` the trunk output at position ``idx[i]``
        is used as the prediction head input and the embedding at
        position 0 as the target. Returns
        ``(prd_actions, prd_lefts, prd_rights, prd_values)``.
        """
        with torch.no_grad():
            embeds = self._model_E(trc)
            hiddens_v = self._model_HV(embeds)
            hiddens_p = self._model_HP(embeds)

            # Gather, per sequence, the hidden state at the index of
            # interest (value and policy paths separately).
            head_v = torch.cat(
                [hiddens_v[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                dim=0)
            head_p = torch.cat(
                [hiddens_p[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                dim=0)
            # First-position embedding of each sequence serves as target.
            targets = torch.cat(
                [embeds[i][0].unsqueeze(0) for i in range(len(idx))],
                dim=0)

            prd_actions, prd_lefts, prd_rights = \
                self._model_PH(head_p, targets)
            prd_values = self._model_VH(head_v, targets)

            return (
                prd_actions,
                prd_lefts,
                prd_rights,
                prd_values,
            )
class LanguageModel:
    """Trainer for the prooftrace language model (E/H/PH/VH modules).

    Supports optional distributed data-parallel training with gradient
    accumulation, periodic checkpointing, live config updates, and
    TensorBoard logging. The value head (VH) path is currently disabled
    (commented out) throughout the train/test loops.
    """

    def __init__(
            self,
            config: Config,
    ):
        """Read hyperparameters and instantiate the four sub-modules."""
        self._config = config

        # Gradient accumulation factor: the configured batch is split
        # into this many smaller forward/backward passes per step.
        self._accumulation_step_count = \
            config.get('prooftrace_lm_accumulation_step_count')
        self._learning_rate = \
            config.get('prooftrace_lm_learning_rate')
        # Weight of the value loss (unused while VH path is disabled).
        self._value_coeff = config.get('prooftrace_lm_value_coeff')

        self._device = torch.device(config.get('device'))

        self._save_dir = config.get('prooftrace_save_dir')
        self._load_dir = config.get('prooftrace_load_dir')

        # Only rank 0 writes TensorBoard events in distributed runs.
        self._tb_writer = None
        if self._config.get('tensorboard_log_dir'):
            if self._config.get('distributed_rank') == 0:
                self._tb_writer = SummaryWriter(
                    self._config.get('tensorboard_log_dir'),
                )

        # Raw modules kept for checkpointing; _model_* aliases may be
        # rebound to DistributedDataParallel wrappers in init_training.
        self._inner_model_E = E(self._config).to(self._device)
        self._inner_model_H = H(self._config).to(self._device)
        self._inner_model_PH = PH(self._config).to(self._device)
        self._inner_model_VH = VH(self._config).to(self._device)

        Log.out(
            "Initializing prooftrace LanguageModel", {
                'parameter_count_E': self._inner_model_E.parameters_count(),
                'parameter_count_H': self._inner_model_H.parameters_count(),
                'parameter_count_PH': self._inner_model_PH.parameters_count(),
                'parameter_count_VH': self._inner_model_VH.parameters_count(),
            },
        )

        self._model_E = self._inner_model_E
        self._model_H = self._inner_model_H
        self._model_PH = self._inner_model_PH
        self._model_VH = self._inner_model_VH

        self._nll_loss = nn.NLLLoss()
        self._mse_loss = nn.MSELoss()

        self._train_batch = 0

    def init_training(
            self,
            train_dataset,
    ):
        """Wrap modules for DDP (if configured), build optimizer/loader.

        The DataLoader batch size is the configured batch size divided
        by the accumulation step count, so the effective per-step batch
        matches the config.
        """
        if self._config.get('distributed_training'):
            self._model_E = torch.nn.parallel.DistributedDataParallel(
                self._inner_model_E,
                device_ids=[self._device],
            )
            self._model_H = torch.nn.parallel.DistributedDataParallel(
                self._inner_model_H,
                device_ids=[self._device],
            )
            self._model_PH = torch.nn.parallel.DistributedDataParallel(
                self._inner_model_PH,
                device_ids=[self._device],
            )
            self._model_VH = torch.nn.parallel.DistributedDataParallel(
                self._inner_model_VH,
                device_ids=[self._device],
            )

        self._optimizer = optim.Adam(
            [
                {'params': self._model_E.parameters()},
                {'params': self._model_H.parameters()},
                {'params': self._model_PH.parameters()},
                {'params': self._model_VH.parameters()},
            ],
            lr=self._learning_rate,
        )

        self._train_sampler = None
        if self._config.get('distributed_training'):
            self._train_sampler = \
                torch.utils.data.distributed.DistributedSampler(
                    train_dataset,
                )

        batch_size = \
            self._config.get('prooftrace_lm_batch_size') // \
            self._accumulation_step_count

        self._train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            # DistributedSampler already shuffles; only shuffle locally
            # when there is no sampler.
            shuffle=(self._train_sampler is None),
            sampler=self._train_sampler,
            collate_fn=lm_collate,
            num_workers=8,
        )

        Log.out(
            'Training initialization', {
                "accumulation_step_count": self._accumulation_step_count,
                "world_size": self._config.get('distributed_world_size'),
                "batch_size": self._config.get('prooftrace_lm_batch_size'),
                "dataloader_batch_size": batch_size,
                "effective_batch_size": (
                    self._config.get('prooftrace_lm_batch_size') *
                    self._config.get('distributed_world_size')
                ),
            })

    def init_testing(
            self,
            test_dataset,
    ):
        """Build the test DataLoader (same reduced batch size as train)."""
        batch_size = \
            self._config.get('prooftrace_lm_batch_size') // \
            self._accumulation_step_count

        self._test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=lm_collate,
            num_workers=8,
        )

    def load(
            self,
            training=True,
    ):
        """Load per-rank checkpoints for the models and (optionally) the
        optimizer; returns ``self`` for chaining.
        """
        rank = self._config.get('distributed_rank')

        if self._load_dir:
            # The E checkpoint acts as the sentinel for a complete set.
            if os.path.isfile(self._load_dir + "/model_E_{}.pt".format(rank)):
                Log.out("Loading prooftrace", {
                    'load_dir': self._load_dir,
                })
                self._inner_model_E.load_state_dict(
                    torch.load(
                        self._load_dir + "/model_E_{}.pt".format(rank),
                        map_location=self._device,
                    ),
                )
                self._inner_model_H.load_state_dict(
                    torch.load(
                        self._load_dir + "/model_H_{}.pt".format(rank),
                        map_location=self._device,
                    ),
                )
                self._inner_model_PH.load_state_dict(
                    torch.load(
                        self._load_dir + "/model_PH_{}.pt".format(rank),
                        map_location=self._device,
                    ),
                )
                self._inner_model_VH.load_state_dict(
                    torch.load(
                        self._load_dir + "/model_VH_{}.pt".format(rank),
                        map_location=self._device,
                    ),
                )
                if training:
                    # Optimizer state is only meaningful when resuming
                    # training (requires init_training to have run).
                    if os.path.isfile(
                            self._load_dir +
                            "/optimizer_{}.pt".format(rank)):
                        self._optimizer.load_state_dict(
                            torch.load(
                                self._load_dir +
                                "/optimizer_{}.pt".format(rank),
                                map_location=self._device,
                            ),
                        )

        return self

    def save(
            self,
    ):
        """Save per-rank checkpoints of the raw modules and optimizer."""
        rank = self._config.get('distributed_rank')

        if self._save_dir:
            Log.out("Saving prooftrace models", {
                'save_dir': self._save_dir,
            })
            torch.save(
                self._inner_model_E.state_dict(),
                self._save_dir + "/model_E_{}.pt".format(rank),
            )
            torch.save(
                self._inner_model_H.state_dict(),
                self._save_dir + "/model_H_{}.pt".format(rank),
            )
            torch.save(
                self._inner_model_PH.state_dict(),
                self._save_dir + "/model_PH_{}.pt".format(rank),
            )
            torch.save(
                self._inner_model_VH.state_dict(),
                self._save_dir + "/model_VH_{}.pt".format(rank),
            )
            torch.save(
                self._optimizer.state_dict(),
                self._save_dir + "/optimizer_{}.pt".format(rank),
            )

    def update(
            self,
    ) -> None:
        """Apply live config changes (learning rate, value coeff).

        Reads pending overrides from the config; when the learning rate
        changed, it is pushed into every optimizer param group.
        """
        update = self._config.update()
        if update:
            if 'prooftrace_lm_learning_rate' in update:
                lr = \
                    self._config.get('prooftrace_lm_learning_rate')
                if lr != self._learning_rate:
                    self._learning_rate = lr
                    for group in self._optimizer.param_groups:
                        group['lr'] = lr
                    Log.out("Updated", {
                        "prooftrace_learning_rate": lr,
                    })
            if 'prooftrace_lm_value_coeff' in update:
                coeff = self._config.get('prooftrace_lm_value_coeff')
                if coeff != self._value_coeff:
                    self._value_coeff = coeff
                    Log.out("Updated", {
                        "prooftrace_lm_value_coeff": coeff,
                    })

            if self._tb_writer is not None:
                for k in update:
                    if k in [
                            'prooftrace_lm_learning_rate',
                            'prooftrace_lm_value_coeff',
                    ]:
                        self._tb_writer.add_scalar(
                            "prooftrace_lm_train_run/{}".format(k),
                            update[k], self._train_batch,
                        )

    def batch_train(
            self,
            epoch,
    ):
        """Run one training epoch over the train loader.

        Per batch: embed actions/arguments, run the trunk, gather the
        per-sequence head/target positions, and optimize the NLL of the
        action, left-pointer and right-pointer predictions. Gradients
        are accumulated over `_accumulation_step_count` batches.
        NOTE(review): the `it % count == 0` test also fires at it == 0,
        i.e. the first step happens after a single mini-batch — confirm
        this is intended (the usual idiom is `(it + 1) % count == 0`).
        """
        assert self._train_loader is not None

        self._model_E.train()
        self._model_H.train()
        self._model_PH.train()
        self._model_VH.train()

        act_loss_meter = Meter()
        lft_loss_meter = Meter()
        rgt_loss_meter = Meter()
        # val_loss_meter = Meter()

        if self._config.get('distributed_training'):
            # Re-seed the sampler so shards reshuffle every epoch.
            self._train_sampler.set_epoch(epoch)

        for it, (idx, act, arg, trh, val) in \
                enumerate(self._train_loader):
            action_embeds = self._model_E(act)
            argument_embeds = self._model_E(arg)

            # action_embeds = \
            #     torch.zeros(action_embeds.size()).to(self._device)
            # argument_embeds = \
            #     torch.zeros(argument_embeds.size()).to(self._device)

            hiddens = self._model_H(action_embeds, argument_embeds)

            # Per-sequence hidden state at the prediction index, and the
            # first-position action embedding as target.
            heads = torch.cat(
                [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                dim=0)
            targets = torch.cat(
                [action_embeds[i][0].unsqueeze(0) for i in range(len(idx))],
                dim=0)

            # Ground truth: action class (offset past the PREPARE
            # tokens) and the argument-list indices of its operands.
            actions = torch.tensor(
                [trh[i].value - len(PREPARE_TOKENS) for i in range(len(trh))],
                dtype=torch.int64).to(self._device)
            lefts = torch.tensor(
                [arg[i].index(trh[i].left) for i in range(len(trh))],
                dtype=torch.int64).to(self._device)
            rights = torch.tensor(
                [arg[i].index(trh[i].right) for i in range(len(trh))],
                dtype=torch.int64).to(self._device)
            # values = torch.tensor(val).unsqueeze(1).to(self._device)

            prd_actions, prd_lefts, prd_rights = \
                self._model_PH(heads, hiddens, targets)
            # prd_values = self._model_VH(heads, targets)

            act_loss = self._nll_loss(prd_actions, actions)
            lft_loss = self._nll_loss(prd_lefts, lefts)
            rgt_loss = self._nll_loss(prd_rights, rights)
            # val_loss = self._mse_loss(prd_values, values)

            # (act_loss + lft_loss + rgt_loss +
            #  self._value_coeff * val_loss).backward()
            (act_loss + lft_loss + rgt_loss).backward()

            if it % self._accumulation_step_count == 0:
                self._optimizer.step()
                self._optimizer.zero_grad()

            act_loss_meter.update(act_loss.item())
            lft_loss_meter.update(lft_loss.item())
            rgt_loss_meter.update(rgt_loss.item())
            # val_loss_meter.update(val_loss.item())

            Log.out(
                "TRAIN BATCH", {
                    'train_batch': self._train_batch,
                    'act_loss_avg': "{:.4f}".format(act_loss.item()),
                    'lft_loss_avg': "{:.4f}".format(lft_loss.item()),
                    'rgt_loss_avg': "{:.4f}".format(rgt_loss.item()),
                    # 'val_loss_avg': "{:.4f}".format(val_loss.item()),
                })

            if self._train_batch % 10 == 0 and self._train_batch != 0:
                Log.out(
                    "PROOFTRACE TRAIN", {
                        'epoch': epoch,
                        'train_batch': self._train_batch,
                        'act_loss_avg': "{:.4f}".format(act_loss_meter.avg),
                        'lft_loss_avg': "{:.4f}".format(lft_loss_meter.avg),
                        'rgt_loss_avg': "{:.4f}".format(rgt_loss_meter.avg),
                        # 'val_loss_avg':
                        #     "{:.4f}".format(val_loss_meter.avg),
                    })

                if self._tb_writer is not None:
                    self._tb_writer.add_scalar(
                        "prooftrace_lm_train/act_loss",
                        act_loss_meter.avg, self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "prooftrace_lm_train/lft_loss",
                        lft_loss_meter.avg, self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "prooftrace_lm_train/rgt_loss",
                        rgt_loss_meter.avg, self._train_batch,
                    )
                    # self._tb_writer.add_scalar(
                    #     "prooftrace_lm_train/val_loss",
                    #     val_loss_meter.avg, self._train_batch,
                    # )

                act_loss_meter = Meter()
                lft_loss_meter = Meter()
                rgt_loss_meter = Meter()
                # val_loss_meter = Meter()

            if self._train_batch % 1000 == 0:
                # Periodic checkpoint + held-out evaluation; test() puts
                # the models in eval mode, so flip back to train after.
                self.save()
                self.test()
                self._model_E.train()
                self._model_H.train()
                self._model_PH.train()
                self._model_VH.train()
                self.update()

            self._train_batch += 1

        Log.out("EPOCH DONE", {
            'epoch': epoch,
        })

    def test(
            self,
    ):
        """Evaluate the current models on the test loader (no grads)."""
        assert self._test_loader is not None

        self._model_E.eval()
        self._model_H.eval()
        self._model_PH.eval()
        self._model_VH.eval()

        act_loss_meter = Meter()
        lft_loss_meter = Meter()
        rgt_loss_meter = Meter()
        # val_loss_meter = Meter()

        test_batch = 0

        with torch.no_grad():
            for it, (idx, act, arg, trh, val) in \
                    enumerate(self._test_loader):
                action_embeds = self._model_E(act)
                argument_embeds = self._model_E(arg)

                hiddens = self._model_H(action_embeds, argument_embeds)

                heads = torch.cat(
                    [hiddens[i][idx[i]].unsqueeze(0)
                     for i in range(len(idx))],
                    dim=0)
                targets = torch.cat([
                    action_embeds[i][0].unsqueeze(0)
                    for i in range(len(idx))
                ], dim=0)

                actions = torch.tensor([
                    trh[i].value - len(PREPARE_TOKENS)
                    for i in range(len(trh))
                ], dtype=torch.int64).to(self._device)
                lefts = torch.tensor(
                    [arg[i].index(trh[i].left) for i in range(len(trh))],
                    dtype=torch.int64).to(self._device)
                rights = torch.tensor(
                    [arg[i].index(trh[i].right) for i in range(len(trh))],
                    dtype=torch.int64).to(self._device)
                # values = torch.tensor(val).unsqueeze(1).to(self._device)

                prd_actions, prd_lefts, prd_rights = \
                    self._model_PH(heads, hiddens, targets)
                # prd_values = self._model_VH(heads, targets)

                act_loss = self._nll_loss(prd_actions, actions)
                lft_loss = self._nll_loss(prd_lefts, lefts)
                rgt_loss = self._nll_loss(prd_rights, rights)
                # val_loss = self._mse_loss(prd_values, values)

                act_loss_meter.update(act_loss.item())
                lft_loss_meter.update(lft_loss.item())
                rgt_loss_meter.update(rgt_loss.item())
                # val_loss_meter.update(val_loss.item())

                Log.out(
                    "TEST BATCH", {
                        'train_batch': self._train_batch,
                        'test_batch': test_batch,
                        'act_loss_avg': "{:.4f}".format(act_loss.item()),
                        'lft_loss_avg': "{:.4f}".format(lft_loss.item()),
                        'rgt_loss_avg': "{:.4f}".format(rgt_loss.item()),
                        # 'val_loss_avg':
                        #     "{:.4f}".format(val_loss.item()),
                    })

                test_batch += 1

        Log.out(
            "PROOFTRACE TEST", {
                'train_batch': self._train_batch,
                'act_loss_avg': "{:.4f}".format(act_loss_meter.avg),
                'lft_loss_avg': "{:.4f}".format(lft_loss_meter.avg),
                'rgt_loss_avg': "{:.4f}".format(rgt_loss_meter.avg),
                # 'val_loss_avg': "{:.4f}".format(val_loss_meter.avg),
            })

        if self._tb_writer is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_test/act_loss",
                act_loss_meter.avg, self._train_batch,
            )
            self._tb_writer.add_scalar(
                "prooftrace_lm_test/lft_loss",
                lft_loss_meter.avg, self._train_batch,
            )
            self._tb_writer.add_scalar(
                "prooftrace_lm_test/rgt_loss",
                rgt_loss_meter.avg, self._train_batch,
            )
def verify():
    """CLI entry point: check that a migrated checkpoint directory loads.

    Parses the config file (with optional --device / --load_dir
    overrides), rebuilds the five-module layout (E, PHI, VHI, PH, VH)
    plus its Adam optimizer, and loads every state dict from
    `prooftrace_load_dir`, failing loudly on any missing file or shape
    mismatch. Prints 'OK' on success.

    Improvements over the previous version: checkpoint paths are built
    with os.path.join (robust to a trailing slash in load_dir) and the
    five copy-pasted load stanzas are collapsed into one loop.
    """
    parser = argparse.ArgumentParser(description="")

    parser.add_argument(
        'config_path',
        type=str, help="path to the config file",
    )
    parser.add_argument(
        '--load_dir',
        type=str, help="config override",
    )
    parser.add_argument(
        '--device',
        type=str, help="config override",
    )

    args = parser.parse_args()

    config = Config.from_file(args.config_path)

    if args.device is not None:
        config.override('device', args.device)
    if args.load_dir is not None:
        config.override(
            'prooftrace_load_dir',
            os.path.expanduser(args.load_dir),
        )

    device = torch.device(config.get('device'))
    load_dir = config.get('prooftrace_load_dir')

    # Target layout after migration: one embedder, two trunk copies
    # (PHI/VHI) and the two heads.
    modules = {
        'E': E(config).to(device),
        'PHI': H(config).to(device),
        'VHI': H(config).to(device),
        'PH': PH(config).to(device),
        'VH': VH(config).to(device),
    }

    # One param group per module, in the same order the migration script
    # saved them, so the optimizer state dict maps onto the right groups.
    optimizer = optim.Adam(
        [
            {'params': modules['E'].parameters()},
            {'params': modules['PHI'].parameters()},
            {'params': modules['VHI'].parameters()},
            {'params': modules['PH'].parameters()},
            {'params': modules['VH'].parameters()},
        ],
        lr=config.get('prooftrace_ppo_learning_rate'),
    )

    assert load_dir is not None

    for name in ('E', 'PHI', 'VHI', 'PH', 'VH'):
        modules[name].load_state_dict(
            torch.load(
                os.path.join(load_dir, 'model_{}.pt'.format(name)),
                map_location=device,
            ))

    optimizer.load_state_dict(
        torch.load(
            os.path.join(load_dir, 'optimizer.pt'),
            map_location=device,
        ))

    Log.out('OK')
def run():
    """One-off checkpoint migration: split the shared trunk H into two.

    Loads an old checkpoint layout (single model_H.pt) into two fresh
    trunk copies (HP for the policy path, HV for the value path), and
    re-saves them under the new names model_PHI.pt / model_VHI.pt that
    the verify() entry point expects, along with E, PH, VH and the
    optimizer state.
    """
    parser = argparse.ArgumentParser(description="")

    parser.add_argument(
        'config_path',
        type=str, help="path to the config file",
    )
    parser.add_argument(
        '--save_dir',
        type=str, help="config override",
    )
    parser.add_argument(
        '--load_dir',
        type=str, help="config override",
    )
    parser.add_argument(
        '--device',
        type=str, help="config override",
    )

    args = parser.parse_args()

    config = Config.from_file(args.config_path)

    if args.device is not None:
        config.override('device', args.device)
    if args.load_dir is not None:
        config.override(
            'prooftrace_load_dir',
            os.path.expanduser(args.load_dir),
        )
    if args.save_dir is not None:
        config.override(
            'prooftrace_save_dir',
            os.path.expanduser(args.save_dir),
        )

    device = torch.device(config.get('device'))
    load_dir = config.get('prooftrace_load_dir')
    save_dir = config.get('prooftrace_save_dir')

    modules = {
        'E': E(config).to(device),
        'HP': H(config).to(device),
        'HV': H(config).to(device),
        'PH': PH(config).to(device),
        'VH': VH(config).to(device),
    }

    # Four groups matching the OLD optimizer checkpoint layout
    # [E, H, PH, VH] so load_state_dict below lines up. NOTE(review):
    # 'HV' params are intentionally absent here; a fifth group is
    # fabricated after loading (see below).
    optimizer = optim.Adam([
        {'params': modules['E'].parameters()},
        {'params': modules['HP'].parameters()},
        {'params': modules['PH'].parameters()},
        {'params': modules['VH'].parameters()},
    ], lr=config.get('prooftrace_ppo_learning_rate'))

    modules['E'].load_state_dict(
        torch.load(
            load_dir + "/model_E.pt",
            map_location=device,
        ))
    # The single old trunk checkpoint seeds BOTH new copies.
    modules['HP'].load_state_dict(
        torch.load(
            load_dir + "/model_H.pt",
            map_location=device,
        ))
    modules['HV'].load_state_dict(
        torch.load(
            load_dir + "/model_H.pt",
            map_location=device,
        ))
    modules['PH'].load_state_dict(
        torch.load(
            load_dir + "/model_PH.pt",
            map_location=device,
        ))
    modules['VH'].load_state_dict(
        torch.load(
            load_dir + "/model_VH.pt",
            map_location=device,
        ))

    optimizer.load_state_dict(
        torch.load(
            load_dir + "/optimizer.pt",
            map_location=device,
        ))

    # Duplicate the trunk's optimizer group so the saved state dict has
    # five groups matching the new [E, PHI, VHI, PH, VH] layout.
    # NOTE(review): this bypasses add_param_group's duplicate check and
    # the copied group still references HP's parameter objects —
    # presumably fine because only the serialized state_dict (indices,
    # not objects) is consumed downstream; confirm against verify().
    new_params = copy.deepcopy(optimizer.param_groups[1])
    optimizer.param_groups.insert(1, new_params)

    if save_dir:
        torch.save(
            modules['E'].state_dict(),
            save_dir + "/model_E.pt",
        )
        # Renamed on save: HP -> model_PHI.pt, HV -> model_VHI.pt.
        torch.save(
            modules['HP'].state_dict(),
            save_dir + "/model_PHI.pt",
        )
        torch.save(
            modules['HV'].state_dict(),
            save_dir + "/model_VHI.pt",
        )
        torch.save(
            modules['PH'].state_dict(),
            save_dir + "/model_PH.pt",
        )
        torch.save(
            modules['VH'].state_dict(),
            save_dir + "/model_VH.pt",
        )
        torch.save(
            optimizer.state_dict(),
            save_dir + "/optimizer.pt",
        )

    Log.out('DONE')
def __init__(
        self,
        config: Config,
):
    """Initialize a PPO ACK worker: modules, env pool, rollout storage.

    Reads the PPO hyperparameters, registers the (E, T, PH, VH) module
    set with the IOTA sync directory, builds the environment pool and
    rollout buffers, and seeds the first observation.
    """
    self._config = config

    self._rollout_size = config.get('prooftrace_ppo_rollout_size')
    self._pool_size = config.get('prooftrace_env_pool_size')
    self._epoch_count = config.get('prooftrace_ppo_epoch_count')
    # PPO surrogate-objective clip range.
    self._clip = config.get('prooftrace_ppo_clip')
    self._grad_norm_max = config.get('prooftrace_ppo_grad_norm_max')

    # Entropy bonuses for the action and pointer distributions.
    self._act_entropy_coeff = \
        config.get('prooftrace_ppo_act_entropy_coeff')
    self._ptr_entropy_coeff = \
        config.get('prooftrace_ppo_ptr_entropy_coeff')

    self._value_coeff = config.get('prooftrace_ppo_value_coeff')

    self._learning_rate = config.get('prooftrace_ppo_learning_rate')

    self._reset_gamma = \
        config.get('prooftrace_ppo_reset_gamma')
    self._fixed_gamma = \
        config.get('prooftrace_ppo_fixed_gamma')
    self._explore_alpha = \
        config.get('prooftrace_ppo_explore_alpha')
    self._explore_beta = \
        config.get('prooftrace_ppo_explore_beta')
    self._explore_beta_width = \
        config.get('prooftrace_ppo_explore_beta_width')
    self._step_reward_prob = \
        config.get('prooftrace_ppo_step_reward_prob')
    self._match_reward_prob = \
        config.get('prooftrace_ppo_match_reward_prob')

    self._device = torch.device(config.get('device'))

    self._modules = {
        'E': E(self._config).to(self._device),
        'T': T(self._config).to(self._device),
        'PH': PH(self._config).to(self._device),
        'VH': VH(self._config).to(self._device),
    }

    # ACK side of the IOTA protocol for the PPO sync directory.
    self._ack = IOTAAck(
        config.get('prooftrace_ppo_iota_sync_dir'),
        self._modules,
    )

    self._rollouts = Rollouts(self._config)

    self._pool = Pool(self._config, False)
    # Seed the first rollout observation for every environment in the
    # pool.
    self._rollouts.observations[0] = self._pool.reset(
        self._reset_gamma,
        self._fixed_gamma,
    )

    # Per-environment running episode rewards (step / match / final).
    self._episode_stp_reward = [0.0] * self._pool_size
    self._episode_mtc_reward = [0.0] * self._pool_size
    self._episode_fnl_reward = [0.0] * self._pool_size

    Log.out(
        'ACK initialization', {
            "pool_size": self._config.get('prooftrace_env_pool_size'),
            "rollout_size": self._config.get('prooftrace_ppo_rollout_size'),
            "batch_size": self._config.get('prooftrace_ppo_batch_size'),
        })