def __init__(self, model, args, device_id=None, verbose=False):
    """Initialize training state: SGD optimizer, word/selection losses, plots.

    Visualization plots are only constructed when ``args.visual`` is set.
    """
    self.model = model
    self.args = args
    self.device_id = device_id
    self.verbose = verbose
    # Nesterov momentum is only valid when the momentum term is positive.
    use_nesterov = args.nesterov and args.momentum > 0
    self.opt = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        nesterov=use_nesterov)
    self.crit = Criterion(model.word_dict, device_id=device_id)
    # Selection loss must never pick tokens that do not form a valid deal.
    self.sel_crit = Criterion(
        model.item_dict,
        device_id=device_id,
        bad_toks=['<disconnect>', '<disagree>'])
    if args.visual:
        self.model_plot = vis.ModulePlot(
            model, plot_weight=False, plot_grad=True)
        self.loss_plot = vis.Plot(
            ['train', 'valid', 'valid_select'],
            'loss', 'loss', 'epoch', running_n=1)
        self.ppl_plot = vis.Plot(
            ['train', 'valid', 'valid_select'],
            'perplexity', 'ppl', 'epoch', running_n=1)
def __init__(self, model, args, verbose=False):
    """Initialize training state using the optimizer factory ``make_opt``.

    Loss/perplexity plots (without file output) are created only when
    ``args.visual`` is set.
    """
    self.model = model
    self.args = args
    self.verbose = verbose
    # Delegate optimizer construction to the subclass/base hook.
    self.opt = self.make_opt(args.lr)
    self.crit = Criterion(model.word_dict)
    # Selection loss must never pick tokens that do not form a valid deal.
    self.sel_crit = Criterion(
        model.item_dict,
        bad_toks=['<disconnect>', '<disagree>'])
    if args.visual:
        self.model_plot = vis.ModulePlot(
            model, plot_weight=True, plot_grad=False)
        self.loss_plot = vis.Plot(
            ['train', 'valid', 'valid_select'],
            'loss', 'loss', 'epoch',
            running_n=1, write_to_file=False)
        self.ppl_plot = vis.Plot(
            ['train', 'valid', 'valid_select'],
            'perplexity', 'ppl', 'epoch',
            running_n=1, write_to_file=False)
def __init__(self, model, args, name='Alice'):
    """Initialize the RL agent: optimizer, reward history, optional plots."""
    super(RlAgent, self).__init__(model, args, name=name)
    # Nesterov momentum is only valid when the momentum term is positive.
    use_nesterov = self.args.nesterov and self.args.momentum > 0
    self.opt = optim.SGD(
        self.model.parameters(),
        lr=self.args.rl_lr,
        momentum=self.args.momentum,
        nesterov=use_nesterov)
    # Running history of rewards, used for reward normalization.
    self.all_rewards = []
    if self.args.visual:
        self.model_plot = vis.ModulePlot(
            self.model, plot_weight=False, plot_grad=True)
        self.reward_plot = vis.Plot(['reward'], 'reward', 'reward')
        self.loss_plot = vis.Plot(['loss'], 'loss', 'loss')
    self.t = 0
    # Explicitly activate training mode to avoid a runtime error
    # with pytorch > 0.4.0.
    self.model.train()
def __init__(self, model, args, name='Alice'):
    """Initialize the RL agent: optimizer, reward history, optional plots."""
    super(RlAgent, self).__init__(model, args, name=name)
    self.opt = optim.SGD(
        self.model.parameters(),
        lr=self.args.rl_lr,
        momentum=self.args.momentum,
        # Nesterov momentum is only valid when the momentum term is positive.
        nesterov=(self.args.nesterov and self.args.momentum > 0))
    # Running history of rewards, used for reward normalization.
    self.all_rewards = []
    if self.args.visual:
        self.model_plot = vis.ModulePlot(
            self.model, plot_weight=False, plot_grad=True)
        self.reward_plot = vis.Plot([
            'reward',
        ], 'reward', 'reward')
        self.loss_plot = vis.Plot([
            'loss',
        ], 'loss', 'loss')
    self.t = 0
    # Fix: explicitly activate training mode, matching the sibling RlAgent
    # __init__ — avoids a runtime error with pytorch > 0.4.0 when
    # backpropagating through a model left in eval mode.
    self.model.train()