import csv
import logging
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Project-local modules used below (import paths assumed, not shown in this file):
# from network import BasicNetwork, StateNet, HypothesisNet
# from utils import log_this, load_rb, get_criterion, get_optimizer, get_potential, get_x_y, goals_loss


# simulator training: only the simulator's parameters are optimized
def train(args):
    net = HypothesisNet(args)

    if not args.no_log:
        log = log_this(net.args, 'logs/sim', args.name, checkpoints=False)

    simulator = net.simulator
    if args.no_reservoir:
        layer = net.W_ro
    else:
        layer = net.reservoir

    batch_size = 10
    criterion = nn.MSELoss()
    train_params = simulator.parameters()
    optimizer = optim.Adam(train_params, lr=args.lr)

    for i in range(args.iters):
        if not args.no_reservoir:
            layer.reset('random')
        optimizer.zero_grad()

        prop = torch.Tensor(np.random.normal(0, 5, size=(batch_size, net.args.D)))
        state = torch.Tensor(np.random.normal(0, 10, size=(batch_size, net.args.L)))

        sim_out = simulator(state, prop)

        # run the reservoir forward_steps times, so the simulator has to predict that far into the future
        outs = []
        for j in range(args.forward_steps):
            outs.append(layer(prop))
        actions = sum(outs)

        # get state output
        layer_out = actions + state
        # validation check: performance should be poor if each state is paired with someone else's actions
        layer_out_val = actions.roll(1, 0) + state

        # mean squared error between the simulator prediction and the actual outcome
        loss = criterion(layer_out, sim_out)
        loss_val = criterion(layer_out_val, sim_out)

        loss.backward()
        optimizer.step()

        if i % 50 == 0 and i != 0:
            print(f'iteration: {i} | loss {loss.item()} | loss_val {loss_val.item()}')

    if not args.no_log:
        save_model_path = os.path.join(log.run_dir, f'model_{log.run_id}.pth')
        save_sim_path = os.path.join(log.run_dir, f'sim_{log.run_id}.pth')
        torch.save(net.state_dict(), save_model_path)
        torch.save(simulator.state_dict(), save_sim_path)
        print(f'saved model to {save_model_path}, sim to {save_sim_path}')
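# Minimal sketch of the interfaces the loop above assumes. These toy modules are
# illustrative only (not the real simulator/reservoir): `simulator(state, prop)`
# maps a (batch, L) state and a (batch, D) proposal to a predicted (batch, L)
# next state, and `layer(prop)` maps the proposal to one step's (batch, L) action.
class _ToySimulator(nn.Module):
    def __init__(self, D, L):
        super().__init__()
        self.fc = nn.Linear(L + D, L)

    def forward(self, state, prop):
        # predict the resulting state from the current state and the proposal
        return self.fc(torch.cat([state, prop], dim=-1))


class _ToyReadout(nn.Module):
    def __init__(self, D, L):
        super().__init__()
        self.fc = nn.Linear(D, L)

    def forward(self, prop):
        # one step's worth of action produced from the proposal
        return self.fc(prop)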
# hypothesizer training: only the hypothesizer's parameters are optimized
def train(args):
    net = HypothesisNet(args)

    if not args.no_log:
        log = log_this(net.args, 'logs/hyp', args.name, checkpoints=False)

    simulator = net.simulator
    hypothesizer = net.hypothesizer
    # assumption: use the same readout/reservoir choice as in the simulator training above
    layer = net.W_ro if args.no_reservoir else net.reservoir

    batch_size = 10
    criterion = nn.MSELoss()
    train_params = hypothesizer.parameters()
    optimizer = optim.Adam(train_params, lr=1e-3)

    for i in range(args.iters):
        optimizer.zero_grad()

        state = torch.Tensor(np.random.normal(0, 10, size=(batch_size, net.args.L)))
        task = torch.Tensor(np.random.normal(0, 10, size=(batch_size, net.args.L)))

        prop = hypothesizer(state, task)
        sim_out = simulator(state, prop)

        # run the reservoir 10 steps, so the simulator predicts 10 steps into the future
        outs = []
        for j in range(10):
            outs.append(layer(prop))
        actions = sum(outs)

        # get state output
        layer_out = actions + state
        # validation check: performance should be poor if each state is paired with someone else's actions
        layer_out_val = actions.roll(1, 0) + state

        # euclidean distance between prediction and outcome, regressed toward zero
        diff = torch.norm(layer_out - sim_out, dim=1)
        loss = criterion(diff, torch.zeros_like(diff))
        diff_val = torch.norm(layer_out_val - sim_out, dim=1)
        loss_val = criterion(diff_val, torch.zeros_like(diff_val))

        loss.backward()
        optimizer.step()

        if i % 50 == 0 and i != 0:
            print(f'iteration: {i} | loss {loss.item()} | loss_val {loss_val.item()}')

    if not args.no_log:
        save_model_path = os.path.join(log.run_dir, f'model_{log.run_id}.pth')
        torch.save(net.state_dict(), save_model_path)
        print(f'saved model to {save_model_path}')
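# Usage sketch (assumption): the fields below are only the ones the training
# loops above read; the real repo presumably builds `args` with argparse, and
# HypothesisNet will likely require further architecture fields (D, L,
# reservoir size, ...) beyond those shown here.
if __name__ == '__main__':
    from argparse import Namespace

    example_args = Namespace(
        name='hyp_test',      # run name passed to the logger
        no_log=True,          # skip logging/saving for a quick smoke test
        no_reservoir=False,   # use the reservoir rather than W_ro alone
        lr=1e-3,              # Adam learning rate (simulator-training variant)
        iters=200,            # number of training iterations
        forward_steps=10,     # prediction horizon (simulator-training variant)
    )
    train(example_args)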
class Trainer:
    def __init__(self, args):
        super().__init__()
        self.args = args

        if self.args.net == 'basic':
            self.net = BasicNetwork(self.args)
        elif self.args.net == 'state':
            self.net = StateNet(self.args)
        elif self.args.net == 'hypothesis':
            self.net = HypothesisNet(self.args)

        # pick which parameters to train and which to leave fixed
        self.n_params = {}
        self.train_params = []
        self.not_train_params = []
        logging.info('Training the following parameters:')
        for k, v in self.net.named_parameters():
            # k is the parameter name, v is the weight tensor
            found = False
            # keep only the parts that will be trained
            for part in self.args.train_parts:
                if part in k:
                    logging.info(f' {k}')
                    self.n_params[k] = (v.shape, v.numel())
                    self.train_params.append(v)
                    found = True
                    break
            if not found:
                self.not_train_params.append(k)
        logging.info('Not training:')
        for k in self.not_train_params:
            logging.info(f' {k}')

        self.criterion = get_criterion(self.args)
        self.optimizer = get_optimizer(self.args, self.train_params)
        self.dset = load_rb(self.args.dataset)
        self.potential = get_potential(self.args)

        # if using separate training and test sets, split them out
        if not self.args.same_test:
            np.random.shuffle(self.dset)
            cutoff = round(.9 * len(self.dset))
            self.train_set = self.dset[:cutoff]
            self.test_set = self.dset[cutoff:]
            logging.info(f'Using separate training ({cutoff}) and test ({len(self.dset) - cutoff}) sets.')
        else:
            self.train_set = self.dset
            self.test_set = self.dset

        self.log_interval = self.args.log_interval
        if not self.args.no_log:
            self.log = self.args.log
            self.run_id = self.args.log.run_id
            self.vis_samples = []
            self.csv_path = open(os.path.join(self.log.run_dir, f'losses_{self.run_id}.csv'), 'a')
            self.writer = csv.writer(self.csv_path, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            self.writer.writerow(['ix', 'avg_loss'])
            self.plot_checkpoint_path = os.path.join(self.log.run_dir, f'checkpoints_{self.run_id}.pkl')
            self.save_model_path = os.path.join(self.log.run_dir, f'model_{self.run_id}.pth')

    def log_model(self, ix=0):
        # saving all checkpoints takes too much space, so we keep only one model at a time
        # unless checkpoint logging is explicitly requested
        if self.args.log_checkpoint_models:
            self.save_model_path = os.path.join(self.log.checkpoint_dir, f'model_{ix}.pth')
        elif os.path.exists(self.save_model_path):
            os.remove(self.save_model_path)
        torch.save(self.net.state_dict(), self.save_model_path)

    def log_checkpoint(self, ix, x, y, z, total_loss, avg_loss):
        self.writer.writerow([ix, avg_loss])
        self.csv_path.flush()
        self.log_model(ix)

        # we can save individual samples at each checkpoint; that's not too bad space-wise
        self.vis_samples.append([ix, x, y, z, total_loss, avg_loss])
        if os.path.exists(self.plot_checkpoint_path):
            os.remove(self.plot_checkpoint_path)
        with open(self.plot_checkpoint_path, 'wb') as f:
            pickle.dump(self.vis_samples, f)

    def train_iteration(self, x, y):
        self.net.reset()
        self.optimizer.zero_grad()

        outs = []
        total_loss = torch.tensor(0.)
        # ins: actual input into the network
        # targets: desired output
        # outs: output of the network
        if self.args.dset_type == 'goals':
            ins = []
            l_other = {'kl': 0, 'lconf': 0, 'lsim': 0, 'lfprop': 0, 'lp': 0}
            targets = x
            cur_idx = torch.zeros(x.shape[0], dtype=torch.long)
            for j in range(self.args.goals_timesteps):
                net_out, step_loss, cur_idx, extras = self.run_iter_goal(x, cur_idx)

                # what we need to record for logging
                ins.append(extras['in'])
                outs.append(net_out[-1].detach().numpy())
                total_loss += step_loss
                if 'kl' in extras and extras['kl'] is not None:
                    l_other['kl'] += extras['kl']
                if 'lconf' in extras and extras['lconf'] is not None:
                    l_other['lconf'] += extras['lconf']
                if 'lsim' in extras and extras['lsim'] is not None:
                    l_other['lsim'] += extras['lsim']
                if 'lp' in extras and extras['lp'] is not None:
                    l_other['lp'] += extras['lp']
                # if 'lfprop' in extras and extras['lfprop'] is not None:
                #     l_other['lfprop'] += extras['lfprop']
            ins = torch.cat(ins)
        else:
            ins = x
            targets = y
            for j in range(x.shape[1]):
                net_out, step_loss, extras = self.run_iter_traj(x[:, j], y[:, j])
                if np.isnan(step_loss.item()):
                    return -1, (net_out, extras)
                total_loss += step_loss
                outs.append(net_out[-1].item())

        total_loss.backward()
        self.optimizer.step()

        etc = {
            'ins': ins,
            'targets': targets,
            'outs': outs,
            'prop': extras['prop'],
        }
        # the extra loss terms and goal indices only exist for the goals task
        if self.args.dset_type == 'goals':
            etc.update(l_other)
            etc['indices'] = cur_idx
        return total_loss, etc

    # runs an iteration where we want to match a certain trajectory
    def run_iter_traj(self, x, y):
        net_in = x.reshape(-1, self.args.L)
        net_out, extras = self.net(net_in, extras=True)
        net_target = y.reshape(-1, self.args.Z)
        step_loss = self.criterion(net_out, net_target)
        return net_out, step_loss, extras

    # runs an iteration where we want to hit a certain goal (dynamic input)
    def run_iter_goal(self, x, indices):
        x_goal = x[torch.arange(x.shape[0]), indices, :]
        net_in = x_goal.reshape(-1, self.args.L)
        net_out, extras = self.net(net_in, extras=True)
        # the target is actually the input
        step_loss, new_indices = goals_loss(net_out, x, indices, threshold=self.args.goals_threshold)

        # non-goals-related losses; each term is None if we just started
        # or if we're not doing the variational pieces
        extras['lp'] = self.potential(net_out).sum()
        step_loss += extras['lp']
        if 'kl' in extras and extras['kl'] is not None:
            step_loss += extras['kl']
        if 'lconf' in extras and extras['lconf'] is not None:
            step_loss += extras['lconf']
        if 'lsim' in extras and extras['lsim'] is not None:
            step_loss += extras['lsim']
        # if 'lfprop' in extras and extras['lfprop'] is not None:
        #     step_loss += extras['lfprop']

        extras.update({'in': net_in})
        return net_out, step_loss, new_indices, extras

    def test(self, n=0):
        if n != 0:
            assert n <= len(self.test_set)
            batch_idxs = np.random.choice(len(self.test_set), n)
            batch = [self.test_set[i] for i in batch_idxs]
        else:
            batch = self.test_set
        x, y = get_x_y(batch, self.args.dataset)

        with torch.no_grad():
            self.net.reset()
            total_loss = torch.tensor(0.)
            if self.args.dset_type == 'goals':
                cur_idx = torch.zeros(x.shape[0], dtype=torch.long)
                for j in range(self.args.goals_timesteps):
                    _, step_loss, cur_idx, _ = self.run_iter_goal(x, cur_idx)
                    total_loss += step_loss
            else:
                for j in range(x.shape[1]):
                    _, step_loss, _ = self.run_iter_traj(x[:, j], y[:, j])
                    total_loss += step_loss

        etc = {}
        if self.args.dset_type == 'goals':
            etc['indices'] = cur_idx

        return total_loss.item() / len(batch), etc

    def train(self, ix_callback=None):
        ix = 0
        its_p_epoch = len(self.train_set) // self.args.batch_size
        logging.info(
            f'Training set size {len(self.train_set)} | batch size {self.args.batch_size} '
            f'--> {its_p_epoch} iterations / epoch')

        # for convergence testing
        max_abs_grads = []
        running_min_error = float('inf')
        running_no_min = 0
        running_loss = 0.0
        # running_mag = 0.0
        ending = False

        for e in range(self.args.n_epochs):
            np.random.shuffle(self.train_set)
            epoch_idx = 0
            while epoch_idx < its_p_epoch:
                epoch_idx += 1
                batch = self.train_set[(epoch_idx - 1) * self.args.batch_size:epoch_idx * self.args.batch_size]
                if len(batch) < self.args.batch_size:
                    break
                ix += 1

                x, y = get_x_y(batch, self.args.dataset)
                loss, etc = self.train_iteration(x, y)

                if ix_callback is not None:
                    ix_callback(loss, etc)

                if loss == -1:
                    logging.info(f'iteration {ix}: loss is nan. ending')
                    ending = True
                    break

                running_loss += loss.item()
                # mag = max([torch.max(torch.abs(p.grad)) for p in self.train_params])
                # running_mag += mag

                if ix % self.log_interval == 0:
                    outs = etc['outs']
                    z = np.stack(outs).squeeze()
                    # average training loss over the last log_interval iterations
                    avg_loss = running_loss / self.args.batch_size / self.log_interval
                    test_loss, test_etc = self.test(n=30)
                    # avg_max_grad = running_mag / self.log_interval
                    log_arr = [
                        f'iteration {ix}',
                        f'train loss {avg_loss:.3f}',
                        # f'max abs grad {avg_max_grad:.3f}',
                        f'test loss {test_loss:.3f}',
                    ]
                    # average index reached, for the goals task
                    if self.args.dset_type == 'goals':
                        avg_index = test_etc['indices'].float().mean().item()
                        log_arr.append(f'avg index {avg_index:.3f}')
                    if self.args.net == 'hypothesis':
                        ha = self.net.log_h_yes.get_input()
                        sa = self.net.log_s_yes.get_input()
                        conf = self.net.log_conf.get_input()
                        lconf, lsim, kl, lp = etc['lconf'], etc['lsim'], etc['kl'], etc['lp']
                        log_arr.append(f'hyp_app {ha:.3f}')
                        log_arr.append(f'sim_app {sa:.3f}')
                        log_arr.append(f'conf {conf:.3f}')
                        log_arr.append(f'lconf {lconf:.3f}')
                        log_arr.append(f'lsim {lsim:.3f}')
                        log_arr.append(f'lp {lp:.3f}')
                        # log_arr.append(f'kl {kl:.3f}')
                    log_str = '\t| '.join(log_arr)
                    logging.info(log_str)

                    if not self.args.no_log:
                        self.log_checkpoint(ix, etc['ins'].numpy(), etc['targets'].numpy(), z, running_loss, avg_loss)
                    running_loss = 0.0
                    running_mag = 0.0

                    # convergence: no decrease in average test loss after `patience` samples
                    if self.args.conv_type == 'patience':
                        if test_loss < running_min_error:
                            running_no_min = 0
                            running_min_error = test_loss
                        else:
                            running_no_min += self.log_interval
                        if running_no_min > self.args.patience:
                            logging.info(f'iteration {ix}: no min for {self.args.patience} samples. ending')
                            ending = True
                    # elif self.args.conv_type == 'grad':
                    #     if avg_max_grad < self.args.grad_threshold:
                    #         logging.info(f'iteration {ix}: max absolute grad < {self.args.grad_threshold}. ending')
                    #         ending = True
                if ending:
                    break
            logging.info(f'Finished dataset epoch {e+1}')
            if ending:
                break

        if not self.args.no_log:
            # for later visualization of outputs over timesteps
            with open(os.path.join(self.log.run_dir, f'checkpoints_{self.run_id}.pkl'), 'wb') as f:
                pickle.dump(self.vis_samples, f)
            self.csv_path.close()

        final_loss, etc = self.test()
        logging.info(f'END | iterations: {(ix // self.log_interval) * self.log_interval} | test loss: {final_loss}')
        return final_loss, ix
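# Usage sketch (assumption): Trainer expects the args object produced by the
# repo's argument parser and logger (net type, dataset path, train_parts,
# batch_size, n_epochs, log settings, ...), none of which are constructed here.
# This only shows how the entry point and the per-iteration callback fit together.
def _example_run(args):
    trainer = Trainer(args)

    def print_grad_norm(loss, etc):
        # example hook called after every training iteration with (loss, etc)
        total = sum(p.grad.norm().item() for p in trainer.train_params if p.grad is not None)
        print(f'grad norm {total:.3f}')

    final_loss, n_iters = trainer.train(ix_callback=print_grad_norm)
    print(f'finished after {n_iters} iterations with test loss {final_loss:.3f}')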