def train(nbatches, npred):
    """Train the forward model for `nbatches` batches.

    Args:
        nbatches: number of minibatches to draw from the 'train' split.
        npred: prediction horizon passed to the dataloader.

    Returns:
        Tuple of per-batch-averaged losses (image, state, prior/KL).
    """
    model.train()
    total_loss_i, total_loss_s, total_loss_p = 0, 0, 0
    for i in range(nbatches):
        optimizer.zero_grad()
        inputs, actions, targets, _, _ = dataloader.get_batch_fm(
            'train', npred)
        inputs = utils.make_variables(inputs)
        targets = utils.make_variables(targets)
        # NOTE(review): torch.autograd.Variable is a legacy (pre-0.4) wrapper;
        # kept as-is since utils.make_variables presumably matches it.
        actions = Variable(actions)
        pred, loss_p = model(inputs, actions, targets,
                             z_dropout=opt.z_dropout)
        loss_p = loss_p[0]
        loss_i, loss_s = compute_loss(targets, pred)
        # Total loss: reconstruction terms plus beta-weighted latent term.
        loss = loss_i + loss_s + opt.beta * loss_p
        # VAEs get NaN loss sometimes, so check for it
        if not math.isnan(loss.item()):
            loss.backward(retain_graph=False)
            # Only step the optimizer if the gradients are finite too.
            if not math.isnan(utils.grad_norm(model).item()):
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               opt.grad_clip)
                optimizer.step()
            total_loss_i += loss_i.item()
            total_loss_s += loss_s.item()
            total_loss_p += loss_p.item()
        del inputs, actions, targets  # free memory before the next batch
    # Average over all requested batches (NaN batches contribute 0).
    total_loss_i /= nbatches
    total_loss_s /= nbatches
    total_loss_p /= nbatches
    return total_loss_i, total_loss_s, total_loss_p
def start(what, nbatches, npred):
    """Run `nbatches` batches of MPUR policy training or evaluation.

    Args:
        what: dataset split name; 'train' enables back-prop updates.
        nbatches: number of minibatches to process.
        npred: prediction horizon passed to the dataloader.

    Returns:
        Dict of losses (proximity, uncertainty, lane, offroad, action,
        policy), each averaged over the non-NaN batches.
    """
    # BUG FIX: original used `what is 'train'`, an identity comparison with a
    # string literal whose result depends on CPython string interning (and
    # raises SyntaxWarning on modern Pythons). Equality is intended.
    train = what == 'train'
    model.train()
    model.policy_net.train()
    n_updates, grad_norm = 0, 0
    total_losses = dict(
        proximity=0,
        uncertainty=0,
        lane=0,
        offroad=0,
        action=0,
        policy=0,
    )
    for j in range(nbatches):
        inputs, actions, targets, ids, car_sizes = dataloader.get_batch_fm(
            what, npred)
        pred, actions = planning.train_policy_net_mpur(
            model, inputs, targets, car_sizes, n_models=10,
            lrt_z=opt.lrt_z, n_updates_z=opt.z_updates, infer_z=opt.infer_z)
        # Total policy cost: weighted sum of the individual cost terms.
        pred['policy'] = pred['proximity'] + \
            opt.u_reg * pred['uncertainty'] + \
            opt.lambda_l * pred['lane'] + \
            opt.lambda_a * pred['action'] + \
            opt.lambda_o * pred['offroad']
        if not math.isnan(pred['policy'].item()):
            if train:
                optimizer.zero_grad()
                pred['policy'].backward()  # back-propagation through time!
                grad_norm += utils.grad_norm(model.policy_net).item()
                torch.nn.utils.clip_grad_norm_(model.policy_net.parameters(),
                                               opt.grad_clip)
                optimizer.step()
            for loss in total_losses:
                total_losses[loss] += pred[loss].item()
            n_updates += 1
        else:
            print('warning, NaN')  # Oh no... Something got quite f****d up!
            ipdb.set_trace()
        if j == 0 and opt.save_movies and train:
            # save videos of normal and adversarial scenarios
            for b in range(opt.batch_size):
                state_img = pred['state_img'][b]
                state_vct = pred['state_vct'][b]
                utils.save_movie(opt.model_file + f'.mov/sampled/mov{b}',
                                 state_img, state_vct, None, actions[b])
        del inputs, actions, targets, pred  # free memory each batch
    # Average over the batches that actually produced finite losses.
    for loss in total_losses:
        total_losses[loss] /= n_updates
    if train:
        print(f'[avg grad norm: {grad_norm / n_updates:.4f}]')
    return total_losses
def train(nbatches, npred):
    """Train the cost model against forward-model rollouts.

    Args:
        nbatches: number of minibatches to draw from the 'train' split.
        npred: prediction horizon passed to the dataloader.

    Returns:
        Average MSE loss over all requested batches.
    """
    model.train()
    total_loss = 0
    for i in range(nbatches):
        optimizer.zero_grad()
        inputs, actions, targets, _, _ = dataloader.get_batch_fm(
            'train', npred)
        # z_dropout=0: deterministic latent usage for cost training.
        pred, _ = model(inputs, actions, targets, z_dropout=0)
        # Flatten (batch, npred) into one batch dimension; pred[0] is the
        # predicted image stream, pred[1] the predicted state stream.
        pred_cost = cost(
            pred[0].view(opt.batch_size * opt.npred, 1, 3, opt.height,
                         opt.width),
            pred[1].view(opt.batch_size * opt.npred, 1, 4))
        # targets[2] holds the ground-truth 2-dim cost per timestep.
        loss = F.mse_loss(pred_cost.view(opt.batch_size, opt.npred, 2),
                          targets[2])
        if not math.isnan(loss.item()):
            loss.backward(retain_graph=False)
            # Skip the optimizer step if the gradients are NaN.
            if not math.isnan(utils.grad_norm(model).item()):
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               opt.grad_clip)
                optimizer.step()
            total_loss += loss.item()
        del inputs, actions, targets  # free memory before the next batch
    total_loss /= nbatches
    return total_loss
def adapt(self, indices, inputs, lbls):
    """One adaptation step with gradient-norm-proportional memory sampling.

    Per-example gradient norms are used as sampling probabilities to decide
    which examples enter the replay memory; the returned loss is computed on
    the current batch concatenated with fetched memory context (if any).

    Args:
        indices: example indices (unused here; kept for interface parity
            with the sibling `adapt`).
        inputs: input batch tensor.
        lbls: label batch tensor.

    Returns:
        Scalar training loss on inputs (+ memory context if present).
    """
    context_x, context_y = self.mem.fetch(inputs)
    out = self.forward(inputs)
    # Temporarily disable reduction to obtain one loss per example.
    self.criterion.reduce = False
    loss_out = self.criterion(out, lbls.squeeze(0))
    self.criterion.reduce = True
    gnorms = []
    for loss in loss_out:
        self.zero_grad()
        # retain_graph=True: the same forward graph is reused for every
        # per-example backward pass.
        loss.backward(retain_graph=True)
        gnorm = utils.grad_norm(self.parameters())
        gnorms.append(gnorm.item())
    self.zero_grad()
    # Normalise gradient norms into a probability distribution.
    gnorms = np.array(gnorms)
    gnorms /= gnorms.sum()
    indices_sampled = np.random.choice(len(gnorms), self.bsz_sampling,
                                       p=gnorms)
    indices_sampled = list(set(indices_sampled))  # de-duplicate draws
    self.mem.add(inputs[indices_sampled], lbls[indices_sampled])
    # BUG FIX: original tested `context_x is not None` twice (copy-paste);
    # the branch concatenates context_y as well, so guard both.
    if context_x is not None and \
            context_y is not None:
        out = self.forward(torch.cat([inputs, context_x], dim=0))
        lbl = torch.cat([lbls, context_y], dim=0)
    else:
        out = self.forward(inputs)
        lbl = lbls
    loss = self.criterion(out, lbl.squeeze(0))
    self.nsteps += 1
    return loss
def adapt(self, indices, inputs, lbls):
    """One adaptation step scoring examples with an auxiliary `baby_mlp`.

    Per-example gradient norms of the baby MLP are recorded and, every
    `add_per` steps, the batch (with its scores) is pushed to memory. The
    returned loss is computed by the main model on the batch concatenated
    with fetched memory context (if any).

    Args:
        indices: example indices forwarded to the memory on add.
        inputs: input batch tensor.
        lbls: label batch tensor.

    Returns:
        Scalar training loss on inputs (+ memory context if present).
    """
    context_i, context_x, context_y, context_g = \
        self.mem.fetch(inputs)
    out = self.baby_mlp(inputs)
    # Temporarily disable reduction to obtain one loss per example.
    self.criterion.reduce = False
    loss_out = self.criterion(out, lbls.squeeze(0))
    self.criterion.reduce = True
    gnorms = []
    for loss in loss_out:
        self.baby_mlp.zero_grad()
        # retain_graph=True: the same forward graph is reused for every
        # per-example backward pass.
        loss.backward(retain_graph=True)
        gnorm = utils.grad_norm(self.baby_mlp.parameters())
        gnorms.append(gnorm)
    self.baby_mlp.zero_grad()
    # Only refresh the memory every `add_per` steps.
    if self.nsteps % self.add_per == 0:
        self.mem.add(indices, inputs, lbls, gnorms)
    # BUG FIX: original tested `context_x is not None` twice (copy-paste);
    # the branch concatenates context_y as well, so guard both.
    if context_x is not None and \
            context_y is not None:
        out = self.forward(torch.cat([inputs, context_x], dim=0))
        lbl = torch.cat([lbls, context_y], dim=0)
    else:
        out = self.forward(inputs)
        lbl = lbls
    loss = self.criterion(out, lbl.squeeze(0))
    self.nsteps += 1
    return loss
def _update(v, runner, criterion):
    """One bi-level update: inner updates of `q`, then an unrolled (via
    `higher`) differentiable inner loop to update `model`, restoring `q`
    afterwards.

    Args:
        v: batch/value object passed through to the criterion losses.
        runner: holds the `model` and `q` networks.
        criterion: holds optimizers/schedulers (`opts`, `schs`), loss
            functions, and `mid_vals` for logging.
    """
    opt_model, opt_q = criterion.opts["model"], criterion.opts["q"]
    sch_model, sch_q = criterion.schs["model"], criterion.schs["q"]
    # update q
    # Freeze the model so inner-loop gradients flow only into q.
    runner.model.requires_grad_(False)
    runner.q.requires_grad_(True)
    for i in range(config.get("update", "n_inner_loops")):
        opt_q.zero_grad()
        criterion.mid_vals["inner_loss"] = inner_loss = \
            criterion.inner_loss(v).mean()
        inner_loss.backward()  # backward
        if config.get("update", "gradient_clip", default=False):
            clip_grad_norm_(runner.q.parameters(), 0.5)
        opt_q.step()  # step
        sch_q.step()
    # update model
    runner.model.requires_grad_(True)
    runner.q.requires_grad_(True)
    # Snapshot q so it can be restored after the differentiable unroll.
    backup_q_state_dict = runner.q.state_dict()
    with higher.innerloop_ctx(runner.q, opt_q) as (fq, diffopt_q):
        # Point the criterion at the functional (differentiable) copy of q.
        criterion.q = fq
        for i in range(config.get("update", "n_unroll")):
            inner_loss = criterion.inner_loss(v).mean()
            diffopt_q.step(inner_loss)
        criterion.loss_val = loss = criterion.loss(v).mean()
        opt_model.zero_grad()
        runner.model.requires_grad_(True)
        # Outer loss updates only the model; q gradients are disabled.
        runner.q.requires_grad_(False)
        loss.backward()
        if config.get("update", "gradient_clip", default=False):
            clip_grad_norm_(runner.model.parameters(), 0.5)
        opt_model.step()
        sch_model.step()
    # Restore q to its pre-unroll weights and re-attach it to the criterion.
    runner.q.load_state_dict(backup_q_state_dict)
    criterion.q = runner.q
    criterion.mid_vals["grad_model"] = grad_norm(runner.model)
    criterion.mid_vals["grad_q"] = grad_norm(runner.q)
def finish_episode():
    """End-of-episode actor-critic update.

    Computes discounted, normalised returns from `model.rewards`, forms the
    policy-gradient and value (smooth-L1) losses from `model.saved_actions`,
    performs one optimizer step, and clears the episode buffers.
    """
    saved_actions = model.saved_actions
    # Discounted returns, computed back-to-front then restored to
    # chronological order.
    returns = []
    running = 0
    for r in reversed(model.rewards):
        running = r + opt.gamma * running
        returns.append(running)
    returns.reverse()
    rewards = torch.tensor(returns).cuda()
    # Normalise for variance reduction (eps guards a zero std).
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
    policy_losses, value_losses = [], []
    for (log_prob, value), ret in zip(saved_actions, rewards):
        advantage = ret - value.item()
        policy_losses.append(-log_prob * advantage)
        value_losses.append(
            F.smooth_l1_loss(value, torch.tensor([ret]).cuda()))
    optimizer.zero_grad()
    total = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
    total.backward()
    print(utils.grad_norm(model).item())
    optimizer.step()
    # Reset the per-episode buffers in place.
    del model.rewards[:]
    del model.saved_actions[:]
def main():
    """Train a Transformer NMT model, periodically evaluate on the dev set,
    checkpoint the best model by the chosen metric, then report BLEU on dev
    and test with the best checkpoint."""
    best_score = 0
    args = get_train_args()
    print(json.dumps(args.__dict__, indent=4))
    # Reading the int indexed text dataset
    train_data = np.load(os.path.join(args.input, args.data + ".train.npy"))
    train_data = train_data.tolist()
    dev_data = np.load(os.path.join(args.input, args.data + ".valid.npy"))
    dev_data = dev_data.tolist()
    test_data = np.load(os.path.join(args.input, args.data + ".test.npy"))
    test_data = test_data.tolist()
    # Reading the vocab file
    with open(os.path.join(args.input, args.data + '.vocab.pickle'),
              'rb') as f:
        id2w = pickle.load(f)
    args.id2w = id2w
    args.n_vocab = len(id2w)
    # Define Model
    model = net.Transformer(args)
    tally_parameters(model)
    if args.gpu >= 0:
        model.cuda(args.gpu)
    print(model)
    optimizer = optim.TransformerAdamTrainer(model, args)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_file):
            print("=> loading checkpoint '{}'".format(args.model_file))
            checkpoint = torch.load(args.model_file)
            args.start_epoch = checkpoint['epoch']
            best_score = checkpoint['best_score']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.model_file))
    # Estimate iterations per epoch from total token counts vs. word-batch
    # size (batches are sized by word count, not sentence count).
    src_data, trg_data = list(zip(*train_data))
    total_src_words = len(list(itertools.chain.from_iterable(src_data)))
    total_trg_words = len(list(itertools.chain.from_iterable(trg_data)))
    iter_per_epoch = (total_src_words + total_trg_words) // args.wbatchsize
    print('Approximate number of iter/epoch =', iter_per_epoch)
    time_s = time()
    global_steps = 0
    for epoch in range(args.start_epoch, args.epoch):
        random.shuffle(train_data)
        # Pool batches of similar lengths to limit padding waste.
        train_iter = data.iterator.pool(
            train_data, args.wbatchsize,
            key=lambda x: data.utils.interleave_keys(len(x[0]), len(x[1])),
            batch_size_fn=batch_size_func,
            random_shuffler=data.iterator.RandomShuffler())
        report_stats = utils.Statistics()
        train_stats = utils.Statistics()
        valid_stats = utils.Statistics()
        if args.debug:
            grad_norm = 0.
        for num_steps, train_batch in enumerate(train_iter):
            global_steps += 1
            model.train()
            optimizer.zero_grad()
            src_iter = list(zip(*train_batch))[0]
            src_words = len(list(itertools.chain.from_iterable(src_iter)))
            report_stats.n_src_words += src_words
            train_stats.n_src_words += src_words
            in_arrays = utils.seq2seq_pad_concat_convert(train_batch, -1)
            loss, stat = model(*in_arrays)
            loss.backward()
            if args.debug:
                # Track the running average gradient norm for debugging.
                norm = utils.grad_norm(model.parameters())
                grad_norm += norm
                if global_steps % args.report_every == 0:
                    print("> Gradient Norm: %1.4f" %
                          (grad_norm / (num_steps + 1)))
            optimizer.step()
            report_stats.update(stat)
            train_stats.update(stat)
            report_stats = report_func(epoch, num_steps, iter_per_epoch,
                                       time_s, report_stats,
                                       args.report_every)
            # Periodic dev evaluation and checkpointing.
            if (global_steps + 1) % args.eval_steps == 0:
                dev_iter = data.iterator.pool(
                    dev_data, args.wbatchsize,
                    key=lambda x: data.utils.interleave_keys(
                        len(x[0]), len(x[1])),
                    batch_size_fn=batch_size_func,
                    random_shuffler=data.iterator.RandomShuffler())
                for dev_batch in dev_iter:
                    model.eval()
                    in_arrays = utils.seq2seq_pad_concat_convert(dev_batch,
                                                                 -1)
                    loss_test, stat = model(*in_arrays)
                    valid_stats.update(stat)
                print('Train perplexity: %g' % train_stats.ppl())
                print('Train accuracy: %g' % train_stats.accuracy())
                print('Validation perplexity: %g' % valid_stats.ppl())
                print('Validation accuracy: %g' % valid_stats.accuracy())
                bleu_score, _ = CalculateBleu(
                    model, dev_data, 'Dev Bleu', batch=args.batchsize // 4,
                    beam_size=args.beam_size, alpha=args.alpha,
                    max_sent=args.max_sent_eval)()
                # Checkpoint selection metric: BLEU or dev accuracy.
                if args.metric == "bleu":
                    score = bleu_score
                elif args.metric == "accuracy":
                    score = valid_stats.accuracy()
                is_best = score > best_score
                best_score = max(score, best_score)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'best_score': best_score,
                        'optimizer': optimizer.state_dict(),
                        'opts': args,
                    }, is_best, args.model_file, args.best_model_file)
    # BLEU score on Dev and Test Data
    checkpoint = torch.load(args.best_model_file)
    print("=> loaded checkpoint '{}' (epoch {}, best score {})".format(
        args.best_model_file, checkpoint['epoch'],
        checkpoint['best_score']))
    model.load_state_dict(checkpoint['state_dict'])
    print('Dev Set BLEU Score')
    _, dev_hyp = CalculateBleu(model, dev_data, 'Dev Bleu',
                               batch=args.batchsize // 4,
                               beam_size=args.beam_size, alpha=args.alpha)()
    save_output(dev_hyp, id2w, args.dev_hyp)
    print('Test Set BLEU Score')
    _, test_hyp = CalculateBleu(model, test_data, 'Test Bleu',
                                batch=args.batchsize // 4,
                                beam_size=args.beam_size, alpha=args.alpha)()
    save_output(test_hyp, id2w, args.test_hyp)
def train(nbatches, npred):
    """Train the MPUR policy network for `nbatches` batches.

    Args:
        nbatches: number of minibatches to draw from the 'train' split.
        npred: prediction horizon passed to the dataloader.

    Returns:
        Tuple of per-update-averaged losses
        (proximity, lane, uncertainty, action, total policy).
    """
    model.train()
    model.policy_net.train()
    total_loss_c, total_loss_u, total_loss_l, total_loss_a, n_updates, \
        grad_norm = 0, 0, 0, 0, 0, 0
    total_loss_policy = 0
    for j in range(nbatches):
        optimizer.zero_grad()
        inputs, actions, targets, ids, car_sizes = dataloader.get_batch_fm(
            'train', npred)
        inputs = utils.make_variables(inputs)
        targets = utils.make_variables(targets)
        pred, actions, pred_adv = planning.train_policy_net_mpur(
            model, inputs, targets, car_sizes, n_models=10,
            lrt_z=opt.lrt_z, n_updates_z=opt.z_updates,
            infer_z=(opt.infer_z == 1))
        loss_c = pred[2]  # proximity cost
        loss_l = pred[3]  # lane cost
        loss_u = pred[4]  # uncertainty cost
        loss_a = actions.norm(2, 2).pow(2).mean()  # action regularisation
        # Weighted total policy cost.
        loss_policy = loss_c + opt.u_reg * loss_u + opt.lambda_l * loss_l \
            + opt.lambda_a * loss_a
        if not math.isnan(loss_policy.item()):
            loss_policy.backward()  # back-propagation through time!
            grad_norm += utils.grad_norm(model.policy_net).item()
            torch.nn.utils.clip_grad_norm_(model.policy_net.parameters(),
                                           opt.grad_clip)
            optimizer.step()
            total_loss_c += loss_c.item()  # proximity cost
            total_loss_u += loss_u.item()  # uncertainty (reg.)
            total_loss_a += loss_a.item()  # action (reg.)
            total_loss_l += loss_l.item()  # lane cost
            total_loss_policy += loss_policy.item()  # overall total cost
            n_updates += 1
        else:
            print('warning, NaN')  # Oh no... Something got quite f****d up!
            pdb.set_trace()
        if j == 0 and opt.save_movies:
            # save videos of normal and adversarial scenarios
            for b in range(opt.batch_size):
                utils.save_movie(opt.model_file + f'.mov/sampled/mov{b}',
                                 pred[0][b], pred[1][b], None, actions[b])
                if pred_adv[0] is not None:
                    utils.save_movie(
                        opt.model_file + f'.mov/adversarial/mov{b}',
                        pred_adv[0][b], pred_adv[1][b], None, actions[b])
        del inputs, actions, targets, pred  # free memory each batch
    # Average over the updates that actually happened (NaN batches skipped).
    total_loss_c /= n_updates
    total_loss_u /= n_updates
    total_loss_a /= n_updates
    total_loss_l /= n_updates
    total_loss_policy /= n_updates
    print(f'[avg grad norm: {grad_norm / n_updates}]')
    return total_loss_c, total_loss_l, total_loss_u, total_loss_a, \
        total_loss_policy
def main():
    """Train a seq2seq model (Noam/Adam/Yogi, optional fp16, EMA weights,
    gradient accumulation, multi-GPU), periodically evaluate, checkpoint the
    best model, then report dev/test BLEU for both raw and EMA weights."""
    best_score = 0
    args = get_train_args()
    logger = get_logger(args.log_path)
    logger.info(json.dumps(args.__dict__, indent=4))
    # Set seed value
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if args.gpu:
        torch.cuda.manual_seed_all(args.seed)
    # Reading the int indexed text dataset
    train_data = np.load(os.path.join(args.input,
                                      args.data + ".train.npy"),
                         allow_pickle=True)
    train_data = train_data.tolist()
    dev_data = np.load(os.path.join(args.input, args.data + ".valid.npy"),
                       allow_pickle=True)
    dev_data = dev_data.tolist()
    test_data = np.load(os.path.join(args.input, args.data + ".test.npy"),
                        allow_pickle=True)
    test_data = test_data.tolist()
    # Reading the vocab file
    with open(os.path.join(args.input, args.data + '.vocab.pickle'),
              'rb') as f:
        id2w = pickle.load(f)
    args.id2w = id2w
    args.n_vocab = len(id2w)
    # Define Model
    # NOTE(review): eval() on a CLI-supplied model name executes arbitrary
    # code if args.model is untrusted — consider a registry dict instead.
    model = eval(args.model)(args)
    model.apply(init_weights)
    tally_parameters(model)
    if args.gpu >= 0:
        model.cuda(args.gpu)
    logger.info(model)
    # Optimizer / LR-scheduler selection.
    if args.optimizer == 'Noam':
        optimizer = NoamAdamTrainer(model, args)
    elif args.optimizer == 'Adam':
        params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = torch.optim.Adam(params, lr=args.learning_rate,
                                     betas=(args.optimizer_adam_beta1,
                                            args.optimizer_adam_beta2),
                                     eps=args.optimizer_adam_epsilon)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.7, patience=7, verbose=True)
    elif args.optimizer == 'Yogi':
        params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = Yogi(params, lr=args.learning_rate,
                         betas=(args.optimizer_adam_beta1,
                                args.optimizer_adam_beta2),
                         eps=args.optimizer_adam_epsilon)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.7, patience=7, verbose=True)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer, static_loss_scale=args.static_loss_scale,
            dynamic_loss_scale=args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2**16}, verbose=False)
    # Exponential moving average of the weights, applied after each step.
    ema = ExponentialMovingAverage(decay=args.ema_decay)
    ema.register(model.state_dict())
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_file):
            logger.info("=> loading checkpoint '{}'".format(args.model_file))
            checkpoint = torch.load(args.model_file)
            args.start_epoch = checkpoint['epoch']
            best_score = checkpoint['best_score']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_file, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(
                args.model_file))
    # Estimate iterations per epoch from token counts (word-sized batches).
    src_data, trg_data = list(zip(*train_data))
    total_src_words = len(list(itertools.chain.from_iterable(src_data)))
    total_trg_words = len(list(itertools.chain.from_iterable(trg_data)))
    iter_per_epoch = (total_src_words + total_trg_words) // \
        (2 * args.wbatchsize)
    logger.info('Approximate number of iter/epoch = {}'.format(
        iter_per_epoch))
    time_s = time()
    global_steps = 0
    num_grad_steps = 0
    if args.grad_norm_for_yogi and args.optimizer == 'Yogi':
        # Epoch -1 is a calibration pass: accumulate squared gradient norms
        # to initialise Yogi's second-moment estimate.
        args.start_epoch = -1
        l2_norm = 0.0
    parameters = list(
        filter(lambda p: p.requires_grad is True, model.parameters()))
    n_params = sum([p.nelement() for p in parameters])
    for epoch in range(args.start_epoch, args.epoch):
        random.shuffle(train_data)
        train_iter = data.iterator.pool(
            train_data, args.wbatchsize,
            key=lambda x: (len(x[0]), len(x[1])),
            batch_size_fn=batch_size_fn,
            random_shuffler=data.iterator.RandomShuffler())
        report_stats = utils.Statistics()
        train_stats = utils.Statistics()
        if args.debug:
            grad_norm = 0.
        for num_steps, train_batch in enumerate(train_iter):
            global_steps += 1
            model.train()
            # Zero grads every step, or once per accumulation window.
            if args.grad_accumulator_count == 1:
                optimizer.zero_grad()
            elif num_grad_steps % args.grad_accumulator_count == 0:
                optimizer.zero_grad()
            src_iter = list(zip(*train_batch))[0]
            src_words = len(list(itertools.chain.from_iterable(src_iter)))
            report_stats.n_src_words += src_words
            train_stats.n_src_words += src_words
            in_arrays = utils.seq2seq_pad_concat_convert(train_batch, -1)
            if len(args.multi_gpu) > 1:
                # Data-parallel forward; combine per-replica stats and form
                # a token-weighted average loss.
                loss_tuple, stat_tuple = zip(
                    *dp(model, in_arrays, device_ids=args.multi_gpu))
                n_total = sum([obj.n_words.item() for obj in stat_tuple])
                n_correct = sum(
                    [obj.n_correct.item() for obj in stat_tuple])
                loss = 0
                for l_, s_ in zip(loss_tuple, stat_tuple):
                    loss += l_ * s_.n_words.item()
                loss /= n_total
                stat = utils.Statistics(loss=loss.data.cpu() * n_total,
                                        n_correct=n_correct,
                                        n_words=n_total)
            else:
                loss, stat = model(*in_arrays)
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            # Calibration epoch: only collect gradient statistics, no step.
            if epoch == -1 and args.grad_norm_for_yogi and \
                    args.optimizer == 'Yogi':
                l2_norm += (utils.grad_norm(model.parameters())**2) / \
                    n_params
                continue
            num_grad_steps += 1
            if args.debug:
                norm = utils.grad_norm(model.parameters())
                grad_norm += norm
                if global_steps % args.report_every == 0:
                    logger.info("> Gradient Norm: %1.4f" %
                                (grad_norm / (num_steps + 1)))
            # Step every iteration, or once per accumulation window; EMA is
            # refreshed after each optimizer step.
            if args.grad_accumulator_count == 1:
                optimizer.step()
                ema.apply(model.state_dict(keep_vars=True))
            elif num_grad_steps % args.grad_accumulator_count == 0:
                optimizer.step()
                ema.apply(model.state_dict(keep_vars=True))
                num_grad_steps = 0
            report_stats.update(stat)
            train_stats.update(stat)
            report_stats = report_func(epoch, num_steps, iter_per_epoch,
                                       time_s, report_stats,
                                       args.report_every)
            valid_stats = utils.Statistics()
            # Periodic evaluation on the dev set.
            if global_steps % args.eval_steps == 0:
                with torch.no_grad():
                    dev_iter = data.iterator.pool(
                        dev_data, args.wbatchsize,
                        key=lambda x: (len(x[0]), len(x[1])),
                        batch_size_fn=batch_size_fn,
                        random_shuffler=data.iterator.RandomShuffler())
                    for dev_batch in dev_iter:
                        model.eval()
                        in_arrays = utils.seq2seq_pad_concat_convert(
                            dev_batch, -1)
                        if len(args.multi_gpu) > 1:
                            _, stat_tuple = zip(*dp(
                                model, in_arrays,
                                device_ids=args.multi_gpu))
                            n_total = sum(
                                [obj.n_words.item() for obj in stat_tuple])
                            n_correct = sum(
                                [obj.n_correct.item()
                                 for obj in stat_tuple])
                            dev_loss = sum(
                                [obj.loss for obj in stat_tuple])
                            stat = utils.Statistics(loss=dev_loss,
                                                    n_correct=n_correct,
                                                    n_words=n_total)
                        else:
                            _, stat = model(*in_arrays)
                        valid_stats.update(stat)
                    logger.info('Train perplexity: %g' % train_stats.ppl())
                    logger.info('Train accuracy: %g' %
                                train_stats.accuracy())
                    logger.info('Validation perplexity: %g' %
                                valid_stats.ppl())
                    logger.info('Validation accuracy: %g' %
                                valid_stats.accuracy())
                    if args.metric == "accuracy":
                        score = valid_stats.accuracy()
                    elif args.metric == "bleu":
                        score, _ = CalculateBleu(
                            model, dev_data, 'Dev Bleu',
                            batch=args.batchsize // 4,
                            beam_size=args.beam_size, alpha=args.alpha,
                            max_sent=args.max_sent_eval)(logger)
                    # Threshold Global Steps to save the model
                    if not (global_steps % 2000):
                        print('saving')
                        is_best = score > best_score
                        best_score = max(score, best_score)
                        save_checkpoint(
                            {
                                'epoch': epoch + 1,
                                'state_dict': model.state_dict(),
                                'state_dict_ema':
                                    ema.shadow_variable_dict,
                                'best_score': best_score,
                                'optimizer': optimizer.state_dict(),
                                'opts': args,
                            }, is_best, args.model_file,
                            args.best_model_file)
                    if args.optimizer == 'Adam' or \
                            args.optimizer == 'Yogi':
                        scheduler.step(score)
        # After the calibration epoch, seed Yogi's second moment.
        if epoch == -1 and args.grad_norm_for_yogi and \
                args.optimizer == 'Yogi':
            optimizer.v_init = l2_norm / (num_steps + 1)
            logger.info("Initializing Yogi Optimizer (v_init = {})".format(
                optimizer.v_init))
    # BLEU score on Dev and Test Data
    checkpoint = torch.load(args.best_model_file)
    logger.info(
        "=> loaded checkpoint '{}' (epoch {}, best score {})".format(
            args.best_model_file, checkpoint['epoch'],
            checkpoint['best_score']))
    model.load_state_dict(checkpoint['state_dict'])
    logger.info('Dev Set BLEU Score')
    _, dev_hyp = CalculateBleu(
        model, dev_data, 'Dev Bleu', batch=args.batchsize // 4,
        beam_size=args.beam_size, alpha=args.alpha,
        max_decode_len=args.max_decode_len)(logger)
    save_output(dev_hyp, id2w, args.dev_hyp)
    logger.info('Test Set BLEU Score')
    _, test_hyp = CalculateBleu(
        model, test_data, 'Test Bleu', batch=args.batchsize // 4,
        beam_size=args.beam_size, alpha=args.alpha,
        max_decode_len=args.max_decode_len)(logger)
    save_output(test_hyp, id2w, args.test_hyp)
    # Loading EMA state dict
    model.load_state_dict(checkpoint['state_dict_ema'])
    logger.info('Dev Set BLEU Score')
    _, dev_hyp = CalculateBleu(
        model, dev_data, 'Dev Bleu', batch=args.batchsize // 4,
        beam_size=args.beam_size, alpha=args.alpha,
        max_decode_len=args.max_decode_len)(logger)
    save_output(dev_hyp, id2w, args.dev_hyp + '.ema')
    logger.info('Test Set BLEU Score')
    _, test_hyp = CalculateBleu(
        model, test_data, 'Test Bleu', batch=args.batchsize // 4,
        beam_size=args.beam_size, alpha=args.alpha,
        max_decode_len=args.max_decode_len)(logger)
    save_output(test_hyp, id2w, args.test_hyp + '.ema')
def start(what, nbatches, npred, split='train',
          return_per_instance_values=False, threshold=0):
    """Run MPUR policy training, finetuning, or per-instance evaluation.

    WARNING(review): this block is corrupted in the source. A span of code
    was redacted to `"******"` below (it originally defined at least
    `instance_loss` and `b_i`, and apparently an `if not math.isnan(...)`
    whose `else:` survives further down), so the function is NOT valid
    Python as-is and must be reconstructed from version control before use.

    NOTE(review): the `is 'train'` / `is 'eval'` comparisons are identity
    checks against string literals — they should be `==` (left unchanged
    here because this is a documentation-only pass).
    """
    train = True if what is 'train' else False
    evaluate = True if what is 'eval' else False
    finetune_train = True if split is 'finetune_train' else False
    finetune_sim = True if split is 'finetune_sim' else False
    model.train()
    model.policy_net.train()
    n_updates, grad_norm = 0, 0
    # Per-instance mode collects lists; aggregate mode accumulates scalars.
    if return_per_instance_values:
        total_losses = dict(
            proximity=[],
            uncertainty=[],
            lane=[],
            offroad=[],
            action=[],
            policy=[],
            episode_timestep_pairs=[],
        )
    else:
        total_losses = dict(
            proximity=0,
            uncertainty=0,
            lane=0,
            offroad=0,
            action=0,
            policy=0,
        )
    if evaluate:
        episode_cost_progression = {}
    iterable = range(nbatches)
    if evaluate or finetune_train:
        # Sweep the whole split instead of a fixed number of batches.
        total_instances = dataloader.get_total_instances(split, what)
        print(f"total_instances in {split}: {total_instances}")
        iterable = range(0, total_instances, opt.batch_size)
        # nbatches = None
    step = 0
    for j in iterable:
        # print("j:",j,n_updates)
        # with tqdm(total=len(iterable)) as progress_bar:
        # start = time.time()
        if not evaluate:
            inputs, actions, targets, ids, car_sizes = \
                dataloader.get_batch_fm(
                    split, npred,
                    cuda=(True if torch.cuda.is_available() and
                          not opt.no_cuda else False),
                    all_batches=(True if finetune_train else False))
        else:
            # Evaluation also returns the episode index of each instance.
            e_index, inputs, actions, targets, ids, car_sizes = \
                dataloader.get_batch_fm(
                    split, npred, return_episode_index=True,
                    cuda=(True if torch.cuda.is_available() and
                          not opt.no_cuda else False),
                    all_batches=True if finetune_train else False,
                    randomize=(True if (finetune_train or finetune_sim)
                               else False))
            # print(np.unique(e_index))#, type(e_index))
            # Sentinel -1 episode index marks exhaustion of the split.
            if -1 in e_index[:, 0]:
                print("breaking now")
                break
        pred, actions = planning.train_policy_net_mpur(
            model, inputs, targets, car_sizes, n_models=10,
            lrt_z=opt.lrt_z, n_updates_z=opt.z_updates,
            infer_z=opt.infer_z, no_cuda=opt.no_cuda,
            return_per_instance_values=(True if evaluate else False))
        # print("costs: ",pred["proximity"].shape, pred['lane'].shape,
        # pred['action'].shape, pred['offroad'].shape)
        # Total policy cost: weighted sum of the individual cost terms.
        pred['policy'] = pred['proximity'] + \
            opt.u_reg * pred['uncertainty'] + \
            opt.lambda_l * pred['lane'] + \
            opt.lambda_a * pred['action'] + \
            opt.lambda_o * pred['offroad']
        # print(torch.mean(pred['policy']))
        # print(pred['policy'].shape)
        # --- REDACTED/CORRUPTED SPAN BELOW: `instance_loss` and `b_i` were
        # defined in the missing code; indentation here is a best guess. ---
        # print("time for loading batches and forward pass: "******"episode" not in key }
        if instance_loss >= threshold:
            episode_index, timestep = e_index[b_i]
            for loss in total_losses:
                if loss != 'episode_timestep_pairs':
                    total_losses[loss].append(
                        torch.mean(pred[loss][b_i]).detach().cpu())
                    episode_cost_progression[
                        e_index[b_i][0]][loss].append(
                            torch.mean(
                                pred[loss][b_i]).detach().cpu())
                else:
                    total_losses[loss].append(
                        [episode_index, timestep, instance_loss])
        # print(type(dataloader.finetune_dict))
        # print(dataloader.finetune_dict)
        # if episode_index not in dataloader.finetune_dict:
        #     dataloader.finetune_dict[episode_index] = []
        # nframes = opt.npred + opt.ncond
        # min_range = max(0,timestep-nframes)
        # max_range = min(len(dataloader.images[episode_index]),timestep+100)-50
        # for frame_index in range(min_range,max_range,opt.finetune_nframes_overlap):
        #     if frame_index not in dataloader.finetune_dict[episode_index]:
        #         dataloader.finetune_dict[episode_index].append(frame_index)
        # else:
        #     # print(type(inputs),type(inputs[0]),type(inputs[0][b_i]))
        #     if finetune_inputs["inputs"]:
        #         for f_i, input_i in enumerate(finetune_inputs["inputs"]):
        #             finetune_inputs["inputs"][input_i].append(inputs[f_i][b_i:b_i+1])
        #         for f_i, input_i in enumerate(finetune_inputs["targets"]):
        #             finetune_inputs["targets"][input_i].append(targets[f_i][b_i:b_i+1])
        #     else:
        #         finetune_inputs["inputs"] = {f_i : [inputs[f_i][b_i:b_i+1]] for f_i in range(len(inputs))}
        #         finetune_inputs["targets"] = {f_i : [targets[f_i][b_i:b_i+1]] for f_i in range(len(targets))}
        #     # print(type(car_sizes), len(car_sizes), type(car_sizes[0]), len(car_sizes[0]))
        #     finetune_inputs["car_sizes"].append(car_sizes[b_i:b_i+1])
        # if len(finetune_inputs["car_sizes"]) == opt.batch_size:
        #     # print([torch.cat(finetune_inputs["inputs"][input_i]).shape for input_i in finetune_inputs["inputs"]])
        #     # for row in inputs:
        #     #     print(row.shape)
        #     # print([torch.cat(finetune_inputs["targets"][input_i]).shape for input_i in finetune_inputs["targets"]])
        #     # for row in targets:
        #     #     print(row.shape)
        #     finetune_inputs["inputs"] = [torch.cat(finetune_inputs["inputs"][input_i]) for input_i in finetune_inputs["inputs"]]
        #     finetune_inputs["targets"] = [torch.cat(finetune_inputs["targets"][input_i]) for input_i in finetune_inputs["targets"]]
        #     finetune_inputs["car_sizes"] = torch.cat(finetune_inputs["car_sizes"])
        #     # t1,t2,t3 = finetune_inputs["targets"]
        #     # print(type(targets), len(targets), type(targets[0]), len(targets[0]))
        #     # print(type(finetune_inputs["targets"]),len(finetune_inputs["targets"]),type(finetune_inputs["targets"][0]), len(finetune_inputs["targets"][0]))
        #     pred, actions = planning.train_policy_net_mpur(
        #         model, finetune_inputs["inputs"], finetune_inputs["targets"], finetune_inputs["car_sizes"], n_models=10, lrt_z=opt.lrt_z, n_updates_z=opt.z_updates, infer_z=opt.infer_z, no_cuda=opt.no_cuda )
        #     pred['policy'] = pred['proximity'] + \
        #         opt.u_reg * pred['uncertainty'] + \
        #         opt.lambda_l * pred['lane'] + \
        #         opt.lambda_a * pred['action'] + \
        #         opt.lambda_o * pred['offroad']
        #     for loss in finetune_losses:
        #         finetune_losses[loss] += pred[loss]
        #     optimizer.zero_grad()
        #     pred['policy'].backward()  # back-propagation through time!
        #     grad_norm += utils.grad_norm(model.policy_net).item()
        #     torch.nn.utils.clip_grad_norm_(model.policy_net.parameters(), opt.grad_clip)
        #     optimizer.step()
        #     n_updates += 1
        #     print(f"update no: {n_updates}")
        #     finetune_inputs["inputs"] = {}
        #     finetune_inputs["car_sizes"] = []
        #     finetune_inputs["targets"] = {}
        else:
            # NOTE(review): the matching `if` for this `else` was inside the
            # redacted span above (likely a NaN check on pred['policy']).
            print('warning, NaN')  # Oh no... Something got quite f****d up!
            ipdb.set_trace()
        if j == 0 and opt.save_movies and train:
            # save videos of normal and adversarial scenarios
            for b in range(opt.batch_size):
                state_img = pred['state_img'][b]
                state_vct = pred['state_vct'][b]
                utils.save_movie(opt.model_file + f'.mov/sampled/mov{b}',
                                 state_img, state_vct, None, actions[b])
        # step += len(actions)
        # progress_bar.update(step)
        del inputs, actions, targets, pred
        if n_updates == nbatches and not evaluate:
            # del dataloader[split]
            break
        # print("time for saving loss values for calc stats later: ", time.time() - start)
    if not evaluate:
        for loss in total_losses:
            total_losses[loss] /= n_updates
    # print(total_losses)
    if train or finetune_train or finetune_sim:
        print(f'[avg grad norm: {grad_norm / n_updates:.4f}]')
    if evaluate:
        pickle.dump(
            episode_cost_progression,
            open("policy_loss_stats/episode_cost_progression.pkl", 'wb+'))
    # if finetune:
    #     return finetune_losses
    # print("final j value: ", j)
    return total_losses
def train(self):
    """Run the full training loop for the audio-tagging model.

    Per epoch: iterate the train loader, optionally apply a time-frequency
    transform, back-propagate with optional gradient clipping, log to the
    (optional) summary writer, then validate and checkpoint the best models
    by validation loss, validation ROC-AUC, and training loss.
    """
    start_t = time.time()
    print(f'Training started at {datetime.now()}')
    print(f'Total number of batches: {len(self.data_loader_train)}')
    # Initial "best" sentinels: losses start high, ROC-AUC starts at 0.
    best_valid_loss, best_train_epoch_loss, best_roc_auc = 10, 10, 0
    best_step_train_loss, best_step_valid_loss, best_step_valid_roc = 0, 0, 0
    drop_counter = 0  # global iteration counter across epochs
    loss_fn = self.model.loss()
    for epoch in range(self.config.num_epochs):
        epoch_loss = 0
        self.model.train()
        ctr = 0
        for ctr, (audio, target, fname) in enumerate(self.data_loader_train):
            #ctr += 1
            drop_counter += 1
            audio = audio.to(self.device)
            target = target.to(self.device)
            # Time-frequency transform
            if self.transforms is not None:
                audio = self.transforms(audio)
            # predict
            out = self.model(audio)
            loss = loss_fn(out, target)
            # back propagation
            self.optimizer.zero_grad()
            loss.backward()
            if self.config.clip_grad > 0:
                clip_grad_norm_(self.model.parameters(),
                                self.config.clip_grad)
            self.optimizer.step()
            epoch_loss += loss.item()
            # print log
            if (ctr) % self.config.print_every == 0:
                print(
                    "[%s] Epoch [%d/%d] Iter [%d/%d] train loss: %.4f Elapsed: %s"
                    % (datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                       epoch + 1, self.config.num_epochs, ctr,
                       len(self.data_loader_train), loss.item(),
                       timedelta(seconds=time.time() - start_t)))
                if self.writer is not None:
                    step = epoch * len(self.data_loader_train) + ctr
                    self.writer.add_scalar('loss', loss.item(), step)
                    self.writer.add_scalar(
                        'learning_rate',
                        self.optimizer.param_groups[0]['lr'], step)
                    self.writer.add_scalar(
                        'grad_norm',
                        utils.grad_norm(self.model.parameters()), step)
            del audio, target  # free GPU memory before the next batch
        epoch_loss = epoch_loss / len(self.data_loader_train)
        # validation
        valid_loss, scores, y_true, y_pred = self._validation(
            start_t, epoch)
        # Plateau scheduler needs the metric; others step unconditionally.
        if self.scheduler is not None:
            if self.config.scheduler == 'plateau':
                self.scheduler.step(valid_loss)
            else:
                self.scheduler.step()
        # Log validation
        if self.writer is not None:
            step = epoch * len(self.data_loader_train) + ctr
            self.writer.add_scalar('valid_loss', valid_loss, step)
            self.writer.add_scalar('valid_roc_auc_macro',
                                   scores['roc_auc_macro'], step)
            if not self.config.debug_mode:
                self.writer.add_figure(
                    'valid_class',
                    utils.compare_predictions(y_true, y_pred,
                                              filepath=None), step)
        # Save model, with respect to validation loss
        if valid_loss < best_valid_loss:
            # print('best model: %4f' % valid_loss)
            best_step_valid_loss = drop_counter
            best_valid_loss = valid_loss
            torch.save(
                self.model.state_dict(),
                os.path.join(self.config.checkpoint_dir,
                             'best_model_valid_loss.pth'))
        # Save model, with respect to validation roc_auc
        if scores['roc_auc_macro'] > best_roc_auc:
            best_step_valid_roc = drop_counter
            best_roc_auc = scores['roc_auc_macro']
            torch.save(
                self.model.state_dict(),
                os.path.join(self.config.checkpoint_dir,
                             'best_model_valid_roc.pth'))
        # Save best model according to training loss
        if epoch_loss < best_train_epoch_loss:
            best_step_train_loss = drop_counter
            best_train_epoch_loss = epoch_loss
            torch.save(
                self.model.state_dict(),
                os.path.join(self.config.checkpoint_dir,
                             'best_model_train.pth'))
    print("{} Training finished. ----------------------- Elapsed: {}".
          format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                 timedelta(seconds=time.time() - start_t)))
    print(
        "Best step (validation loss) = {} . ".format(best_step_valid_loss))
    print("Best step (validation roc_auc) = {} .".format(
        best_step_valid_roc))
    print("Best step (training loss) = {} .".format(best_step_train_loss))
    # Save last model
    torch.save(
        self.model.state_dict(),
        os.path.join(self.config.checkpoint_dir, 'best_model_final.pth'))