if __name__ == '__main__':
    # Start from the shared CLI flags and add the experiment-specific ones.
    parser = build_flags()
    parser.add_argument('--data_path')
    parser.add_argument('--same_data_norm', action='store_true')
    parser.add_argument('--no_data_norm', action='store_true')
    parser.add_argument('--error_out_name', default='prediction_errors_%dstep.npy')
    parser.add_argument('--prior_variance', type=float, default=5e-5)
    parser.add_argument('--test_burn_in_steps', type=int, default=10)
    parser.add_argument('--error_suffix')
    parser.add_argument('--subject_ind', type=int, default=-1)

    args = parser.parse_args()
    params = vars(args)

    # Seed before any data/model construction for reproducibility.
    misc.seed(args.seed)

    # Fixed dataset/model dimensions for the small synthetic task.
    params.update({
        'num_vars': 3,
        'input_size': 4,
        'input_time_steps': 50,
        'nll_loss_type': 'gaussian',
    })

    train_data = SmallSynthData(args.data_path, 'train', params)
    val_data = SmallSynthData(args.data_path, 'val', params)
    model = model_builder.build_model(params)

    if args.mode == 'train':
        with train_utils.build_writers(args.working_dir) as (train_writer, val_writer):
            train.train(model, train_data, val_data, params, train_writer, val_writer)
def train(model, train_data, val_data, params, train_writer, val_writer):
    """Actor-critic style training loop.

    Per training batch this alternates: (1) one critic update on the
    ``q_net`` / ``Q_graph`` parameters via ``model.calculate_loss_q``,
    (2) a polyak (moving-average) update of the target networks, and
    (3) two policy updates via ``model.calculate_loss_pi``. Validation
    runs every epoch; the best model (by tuning loss) is saved to
    ``<working_dir>/best_model`` and a rolling checkpoint is written
    every epoch so training can resume.
    """
    # --- options pulled from params ------------------------------------
    gpu = params.get('gpu', False)
    batch_size = params.get('batch_size', 1000)
    val_batch_size = params.get('val_batch_size', batch_size)
    if val_batch_size is None:
        val_batch_size = batch_size
    accumulate_steps = params.get('accumulate_steps')  # NOTE(review): unused in this variant
    training_scheduler = params.get('training_scheduler', None)
    # NOTE(review): reads the same 'training_scheduler' key as above; both
    # values are overwritten by build_scheduler below, so these reads are dead.
    q_training_scheduler = params.get('training_scheduler', None)
    num_epochs = params.get('num_epochs', 100)
    val_interval = params.get('val_interval', 1)  # NOTE(review): unused here
    val_start = params.get('val_start', 0)  # NOTE(review): unused here
    clip_grad = params.get('clip_grad', None)  # NOTE(review): unused; clipping is hard-coded to 1.0 below
    clip_grad_norm = params.get('clip_grad_norm', None)  # NOTE(review): unused here
    normalize_nll = params.get('normalize_nll', False)
    normalize_kl = params.get('normalize_kl', False)  # NOTE(review): unused here
    tune_on_nll = params.get('tune_on_nll', False)
    verbose = params.get('verbose', False)
    val_teacher_forcing = params.get('val_teacher_forcing', False)
    continue_training = params.get('continue_training', False)
    train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    val_data_loader = DataLoader(val_data, batch_size=val_batch_size)
    lr = params['lr']
    wd = params.get('wd', 0.)
    mom = params.get('mom', 0.)
    # don't send q_net params to policy optimizer
    for p in model.decoder.q_net.parameters():
        p.requires_grad = False
    for p in model.Q_graph.parameters():
        p.requires_grad = False
    model_params = [param for param in model.parameters() if param.requires_grad]
    # Two optimizers: `opt` for policy parameters, `q_opt` for the critic
    # (q_net + Q_graph) parameters only.
    if params.get('use_adam', False):
        opt = torch.optim.Adam(model_params, lr=lr, weight_decay=wd)
        q_opt = torch.optim.Adam(list(model.decoder.q_net.parameters()) + list(model.Q_graph.parameters()), lr=lr, weight_decay=wd)
    else:
        opt = torch.optim.SGD(model_params, lr=lr, weight_decay=wd, momentum=mom)
        q_opt = torch.optim.SGD(list(model.decoder.q_net.parameters()) + list(model.Q_graph.parameters()), lr=lr, weight_decay=wd, momentum=mom)
    working_dir = params['working_dir']
    best_path = os.path.join(working_dir, 'best_model')
    checkpoint_dir = os.path.join(working_dir, 'model_checkpoint')
    training_path = os.path.join(working_dir, 'training_checkpoint')
    if continue_training:
        # Restore model weights, both optimizer states, and bookkeeping.
        print("RESUMING TRAINING")
        model.load(checkpoint_dir)
        train_params = torch.load(training_path)
        start_epoch = train_params['epoch']
        opt.load_state_dict(train_params['optimizer'])
        q_opt.load_state_dict(train_params['q_optimizer'])
        best_val_result = train_params['best_val_result']
        best_val_epoch = train_params['best_val_epoch']
        print("STARTING EPOCH: ",start_epoch)
    else:
        start_epoch = 1
        best_val_epoch = -1
        best_val_result = 10000000
    training_scheduler = train_utils.build_scheduler(opt, params)
    q_training_scheduler = train_utils.build_scheduler(q_opt, params)
    end = start = 0
    misc.seed(1)
    for epoch in range(start_epoch, num_epochs+1):
        print("EPOCH", epoch, (end-start))
        model.train()
        model.train_percent = epoch / num_epochs
        start = time.time()
        for batch_ind, batch in enumerate(train_data_loader):
            inputs = batch['inputs']
            if gpu:
                inputs = inputs.cuda(non_blocking=True)
            # critic training
            # Unfreeze critic params so q_opt can update them.
            for p in model.decoder.q_net.parameters():
                p.requires_grad = True
            for p in model.Q_graph.parameters():
                p.requires_grad = True
            q_opt.zero_grad()
            opt.zero_grad()
            for _ in range(1):  # single critic step per batch
                loss_critic, loss_nll = model.calculate_loss_q(inputs, is_train=True, return_logits=True)
                loss_critic.backward()
                print(loss_critic,"crit","0.9",epoch)
                print(loss_nll,"nll")
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                q_opt.step()
                q_opt.zero_grad()
                opt.zero_grad()
            # Re-freeze critic params so the policy optimizer never touches them.
            for p in model.decoder.q_net.parameters():
                p.requires_grad = False
            for p in model.Q_graph.parameters():
                p.requires_grad = False
            # Finally, update Q_target networks by polyak averaging.
            # We only do it for the q_net, as we don't use the target policy
            with torch.no_grad():
                polyak=0.9
                for p, p_targ in zip(model.decoder.q_net.parameters(), model.decoder_targ.q_net.parameters()):
                    # NB: We use an in-place operations "mul_", "add_" to update target
                    # params, as opposed to "mul" and "add", which would make new tensors.
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)
                for p, p_targ in zip(model.Q_graph.parameters(), model.Q_graph_targ.parameters()):
                    # NB: We use an in-place operations "mul_", "add_" to update target
                    # params, as opposed to "mul" and "add", which would make new tensors.
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)
            # policy training
            q_opt.zero_grad()
            opt.zero_grad()
            for _ in range(2):  # two policy steps per critic step
                loss, loss_policy, loss_kl, logits, _ = model.calculate_loss_pi(inputs, is_train=True, return_logits=True)
                loss.backward()
                print(loss, "pol")
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                opt.zero_grad()
                q_opt.zero_grad()
        # Per-epoch LR scheduling for both optimizers.
        if training_scheduler is not None:
            training_scheduler.step()
            q_training_scheduler.step()
        if train_writer is not None:
            # Logs the losses from the LAST training batch of the epoch.
            train_writer.add_scalar('loss', loss.item(), global_step=epoch)
            if normalize_nll:
                train_writer.add_scalar('NLL', loss_nll.mean().item(), global_step=epoch)
            else:
                train_writer.add_scalar('NLL', loss_nll.mean().item()/(inputs.size(1)*inputs.size(2)), global_step=epoch)
            train_writer.add_scalar("KL Divergence", loss_kl.mean().item(), global_step=epoch)
        # --- validation ------------------------------------------------
        model.eval()
        opt.zero_grad()
        total_nll = 0
        total_kl = 0
        if verbose:
            print("COMPUTING VAL LOSSES")
        with torch.no_grad():
            for batch_ind, batch in enumerate(val_data_loader):
                inputs = batch['inputs']
                if gpu:
                    inputs = inputs.cuda(non_blocking=True)
                loss_critic, loss_nll = model.calculate_loss_q(inputs, is_train=False, teacher_forcing=val_teacher_forcing, return_logits=True)
                loss, loss_policy, loss_kl, logits, _ = model.calculate_loss_pi(inputs, is_train=False, teacher_forcing=val_teacher_forcing, return_logits=True)
                total_kl += loss_kl.sum().item()
                total_nll += loss_nll.sum().item()
                if verbose:
                    print("\tVAL BATCH %d of %d: %f, %f"%(batch_ind+1, len(val_data_loader), loss_nll.mean(), loss_kl.mean()))
        # Average over the whole validation set.
        total_kl /= len(val_data)
        total_nll /= len(val_data)
        total_loss = model.kl_coef*total_kl + total_nll #TODO: this is a thing you fixed
        if val_writer is not None:
            val_writer.add_scalar('loss', total_loss, global_step=epoch)
            val_writer.add_scalar("NLL", total_nll, global_step=epoch)
            val_writer.add_scalar("KL Divergence", total_kl, global_step=epoch)
        # Model selection criterion: NLL alone, or the full ELBO-style loss.
        if tune_on_nll:
            tuning_loss = total_nll
        else:
            tuning_loss = total_loss
        if tuning_loss < best_val_result:
            best_val_epoch = epoch
            best_val_result = tuning_loss
            print("BEST VAL RESULT. SAVING MODEL...")
            model.save(best_path)
        # Always write a rolling checkpoint so training can resume.
        model.save(checkpoint_dir)
        torch.save({
            'epoch':epoch+1,
            'optimizer':opt.state_dict(),
            'q_optimizer':q_opt.state_dict(),
            'best_val_result':best_val_result,
            'best_val_epoch':best_val_epoch,
        }, training_path)
        print("EPOCH %d EVAL: "%epoch)
        print("\tCURRENT VAL LOSS: %f"%tuning_loss)
        print("\tBEST VAL LOSS: %f"%best_val_result)
        print("\tBEST VAL EPOCH: %d"%best_val_epoch)
        end = time.time()
def train(model, train_data, val_data, params, train_writer, val_writer):
    """Gradient-accumulating training loop with sub-batching and optional masks.

    NOTE(review): this file contains several ``train`` definitions; if they
    live in one module, later definitions shadow earlier ones.

    Each DataLoader batch is split into sub-batches of ``sub_batch_size``;
    each sub-batch loss is divided by
    ``sub_steps * accumulate_steps * num_decoder_samples`` so the
    accumulated gradient matches one full-batch step. Validation runs when
    ``(epoch + 1) % val_interval == 0``; the best model (by tuning loss)
    goes to ``<working_dir>/best_model`` and a rolling checkpoint is
    written every validated epoch.
    """
    gpu = params.get('gpu', False)
    batch_size = params.get('batch_size', 1000)
    sub_batch_size = params.get('sub_batch_size')
    if sub_batch_size is None:
        sub_batch_size = batch_size
    val_batch_size = params.get('val_batch_size', batch_size)
    if val_batch_size is None:
        val_batch_size = batch_size
    accumulate_steps = params.get('accumulate_steps', 1)
    training_scheduler = params.get('training_scheduler', None)  # NOTE(review): overwritten by build_scheduler below
    num_epochs = params.get('num_epochs', 100)
    val_interval = params.get('val_interval', 1)
    val_start = params.get('val_start', 0)  # NOTE(review): unused here
    clip_grad = params.get('clip_grad', None)
    clip_grad_norm = params.get('clip_grad_norm', None)
    normalize_nll = params.get('normalize_nll', False)
    normalize_kl = params.get('normalize_kl', False)  # NOTE(review): unused here
    tune_on_nll = params.get('tune_on_nll', False)
    verbose = params.get('verbose', False)
    val_teacher_forcing = params.get('val_teacher_forcing', False)
    collate_fn = params.get('collate_fn', None)
    continue_training = params.get('continue_training', False)
    normalize_inputs = params['normalize_inputs']
    num_decoder_samples = 1
    train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)
    print("NUM BATCHES: ", len(train_data_loader))
    val_data_loader = DataLoader(val_data, batch_size=val_batch_size, collate_fn=collate_fn)
    lr = params['lr']
    wd = params.get('wd', 0.)
    mom = params.get('mom', 0.)
    model_params = [
        param for param in model.parameters() if param.requires_grad
    ]
    if params.get('use_adam', False):
        opt = torch.optim.Adam(model_params, lr=lr, weight_decay=wd)
    else:
        opt = torch.optim.SGD(model_params, lr=lr, weight_decay=wd, momentum=mom)
    working_dir = params['working_dir']
    best_path = os.path.join(working_dir, 'best_model')
    checkpoint_dir = os.path.join(working_dir, 'model_checkpoint')
    training_path = os.path.join(working_dir, 'training_checkpoint')
    if continue_training:
        # Restore model weights, optimizer state, step counter, and bookkeeping.
        print("RESUMING TRAINING")
        model.load(checkpoint_dir)
        train_params = torch.load(training_path)
        start_epoch = train_params['epoch']
        opt.load_state_dict(train_params['optimizer'])
        best_val_result = train_params['best_val_result']
        best_val_epoch = train_params['best_val_epoch']
        model.steps = train_params['step']
        print("STARTING EPOCH: ", start_epoch)
    else:
        start_epoch = 1
        best_val_epoch = -1
        best_val_result = 10000000
    training_scheduler = train_utils.build_scheduler(opt, params)
    end = start = 0
    misc.seed(1)
    for epoch in range(start_epoch, num_epochs + 1):
        model.epoch = epoch
        print("EPOCH", epoch, (end - start))
        model.train_percent = epoch / num_epochs
        start = time.time()
        for batch_ind, batch in enumerate(train_data_loader):
            model.train()
            inputs = batch['inputs']
            # Optional per-batch extras (masked/variable-graph datasets).
            masks = batch.get('masks', None)
            node_inds = batch.get('node_inds', None)
            graph_info = batch.get('graph_info', None)
            if gpu:
                inputs = inputs.cuda(non_blocking=True)
                if masks is not None:
                    masks = masks.cuda(non_blocking=True)
            args = {'is_train': True, 'return_logits': True}
            # Number of sub-batches each loader batch is split into.
            sub_steps = len(range(0, batch_size, sub_batch_size))
            for sub_batch_ind in range(0, batch_size, sub_batch_size):
                sub_inputs = inputs[sub_batch_ind:sub_batch_ind + sub_batch_size]
                for sample in range(num_decoder_samples):
                    if normalize_inputs:
                        # Normalization uses the FULL batch (minus the last
                        # time step), then is sliced per sub-batch.
                        if masks is not None:
                            normalized_inputs = model.normalize_inputs(inputs[:, :-1], masks[:, :-1])
                        else:
                            normalized_inputs = model.normalize_inputs(inputs[:, :-1])
                        args['normalized_inputs'] = normalized_inputs[sub_batch_ind:sub_batch_ind + sub_batch_size]
                    if masks is not None:
                        sub_masks = masks[sub_batch_ind:sub_batch_ind + sub_batch_size]
                        sub_node_inds = node_inds[sub_batch_ind:sub_batch_ind + sub_batch_size]
                        sub_graph_info = graph_info[sub_batch_ind:sub_batch_ind + sub_batch_size]
                        loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(sub_inputs, sub_masks, sub_node_inds, sub_graph_info, **args)
                    else:
                        loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(sub_inputs, **args)
                    # Scale so the accumulated gradient matches one full step.
                    loss = loss / (sub_steps * accumulate_steps * num_decoder_samples)
                    loss.backward()
                    if verbose:
                        tmp_batch_ind = batch_ind * sub_steps + sub_batch_ind + 1
                        tmp_total_batch = len(train_data_loader) * sub_steps
                        print("\tBATCH %d OF %d: %f, %f, %f" % (tmp_batch_ind, tmp_total_batch, loss.item(), loss_nll.mean().item(), loss_kl.mean().item()))
            # Apply the accumulated gradients every `accumulate_steps`
            # batches (-1 means every batch).
            if accumulate_steps == -1 or (batch_ind + 1) % accumulate_steps == 0:
                if verbose and accumulate_steps > 0:
                    print("\tUPDATING WEIGHTS")
                if clip_grad is not None:
                    nn.utils.clip_grad_value_(model.parameters(), clip_grad)
                elif clip_grad_norm is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                opt.step()
                model.steps += 1
                opt.zero_grad()
                # End the epoch early if too few batches remain to fill
                # another accumulation window.
                if accumulate_steps > 0 and accumulate_steps > len(train_data_loader) - batch_ind - 1:
                    break
        if training_scheduler is not None:
            training_scheduler.step()
        if train_writer is not None:
            # Un-scale the last sub-batch loss before logging.
            train_writer.add_scalar('loss', loss.item() * (sub_steps * accumulate_steps * num_decoder_samples), global_step=epoch)
            if normalize_nll:
                train_writer.add_scalar('NLL', loss_nll.mean().item(), global_step=epoch)
            else:
                train_writer.add_scalar('NLL', loss_nll.mean().item() / (inputs.size(1) * inputs.size(2)), global_step=epoch)
            train_writer.add_scalar("KL Divergence", loss_kl.mean().item(), global_step=epoch)
        if ((epoch + 1) % val_interval != 0):
            # Skip validation this epoch.
            end = time.time()
            continue
        # --- validation ------------------------------------------------
        model.eval()
        opt.zero_grad()
        total_nll = 0
        total_kl = 0
        if verbose:
            print("COMPUTING VAL LOSSES")
        with torch.no_grad():
            for batch_ind, batch in enumerate(val_data_loader):
                inputs = batch['inputs']
                masks = batch.get('masks', None)
                node_inds = batch.get('node_inds', None)
                graph_info = batch.get('graph_info', None)
                if gpu:
                    inputs = inputs.cuda(non_blocking=True)
                    if masks is not None:
                        masks = masks.cuda(non_blocking=True)
                if masks is not None:
                    loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(inputs, masks, node_inds, graph_info, is_train=False, teacher_forcing=val_teacher_forcing, return_logits=True)
                else:
                    loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(inputs, is_train=False, teacher_forcing=val_teacher_forcing, return_logits=True)
                total_kl += loss_kl.sum().item()
                total_nll += loss_nll.sum().item()
                if verbose:
                    print("\tVAL BATCH %d of %d: %f, %f" % (batch_ind + 1, len(val_data_loader), loss_nll.mean(), loss_kl.mean()))
        # Average over the whole validation set.
        total_kl /= len(val_data)
        total_nll /= len(val_data)
        total_loss = model.kl_coef * total_kl + total_nll #TODO: this is a thing you fixed
        #total_loss = total_kl + total_nll
        if val_writer is not None:
            val_writer.add_scalar('loss', total_loss, global_step=epoch)
            val_writer.add_scalar("NLL", total_nll, global_step=epoch)
            val_writer.add_scalar("KL Divergence", total_kl, global_step=epoch)
        # Model selection criterion: NLL alone, or the full combined loss.
        if tune_on_nll:
            tuning_loss = total_nll
        else:
            tuning_loss = total_loss
        if tuning_loss < best_val_result:
            best_val_epoch = epoch
            best_val_result = tuning_loss
            print("BEST VAL RESULT. SAVING MODEL...")
            model.save(best_path)
        # Always write a rolling checkpoint so training can resume.
        model.save(checkpoint_dir)
        torch.save(
            {
                'epoch': epoch + 1,
                'optimizer': opt.state_dict(),
                'best_val_result': best_val_result,
                'best_val_epoch': best_val_epoch,
                'step': model.steps,
            }, training_path)
        print("EPOCH %d EVAL: " % epoch)
        print("\tCURRENT VAL LOSS: %f" % tuning_loss)
        print("\tBEST VAL LOSS: %f" % best_val_result)
        print("\tBEST VAL EPOCH: %d" % best_val_epoch)
        end = time.time()
def train(model, train_data, val_data, params, train_writer, val_writer):
    """Training loop with an auxiliary edge discriminator.

    NOTE(review): this file contains several ``train`` definitions; if they
    live in one module, later definitions shadow earlier ones.

    Alongside the main model loss, a ``DNRI_Disc`` network is trained to
    predict the encoder's (argmax) edge assignments from consecutive
    prediction pairs; the discriminator is also passed into
    ``model.calculate_loss``. Both networks get their own optimizer and
    scheduler. Validation runs every epoch; the best model (by tuning
    loss) is saved to ``<working_dir>/best_model``.
    """
    gpu = params.get('gpu', False)
    batch_size = params.get('batch_size', 1000)
    val_batch_size = params.get('val_batch_size', batch_size)
    if val_batch_size is None:
        val_batch_size = batch_size
    accumulate_steps = params.get('accumulate_steps')
    training_scheduler = params.get('training_scheduler', None)  # NOTE(review): overwritten by build_scheduler below
    num_epochs = params.get('num_epochs', 100)
    val_interval = params.get('val_interval', 1)  # NOTE(review): unused here
    val_start = params.get('val_start', 0)  # NOTE(review): unused here
    clip_grad = params.get('clip_grad', None)
    clip_grad_norm = params.get('clip_grad_norm', None)
    normalize_nll = params.get('normalize_nll', False)
    normalize_kl = params.get('normalize_kl', False)  # NOTE(review): unused here
    tune_on_nll = params.get('tune_on_nll', False)
    verbose = params.get('verbose', False)
    val_teacher_forcing = params.get('val_teacher_forcing', False)
    continue_training = params.get('continue_training', False)
    train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    val_data_loader = DataLoader(val_data, batch_size=val_batch_size)
    lr = params['lr']
    wd = params.get('wd', 0.)
    mom = params.get('mom', 0.)
    # Discriminator network with its own parameter list (always on GPU).
    disc = dnriModels.DNRI_Disc(params).cuda()
    disc_params = [param for param in disc.parameters() if param.requires_grad]
    model_params = [
        param for param in model.parameters() if param.requires_grad
    ]
    # Separate optimizers for the discriminator and the main model.
    if params.get('use_adam', False):
        disc_opt = torch.optim.Adam(disc_params, lr=lr, weight_decay=wd)
        opt = torch.optim.Adam(model_params, lr=lr, weight_decay=wd)
    else:
        disc_opt = torch.optim.SGD(disc_params, lr=lr, weight_decay=wd, momentum=mom)
        opt = torch.optim.SGD(model_params, lr=lr, weight_decay=wd, momentum=mom)
    working_dir = params['working_dir']
    best_path = os.path.join(working_dir, 'best_model')
    checkpoint_dir = os.path.join(working_dir, 'model_checkpoint')
    training_path = os.path.join(working_dir, 'training_checkpoint')
    if continue_training:
        # NOTE(review): only the main model/optimizer are restored here;
        # the discriminator restarts from scratch on resume.
        print("RESUMING TRAINING")
        model.load(checkpoint_dir)
        train_params = torch.load(training_path)
        start_epoch = train_params['epoch']
        opt.load_state_dict(train_params['optimizer'])
        best_val_result = train_params['best_val_result']
        best_val_epoch = train_params['best_val_epoch']
        print("STARTING EPOCH: ", start_epoch)
    else:
        start_epoch = 1
        best_val_epoch = -1
        best_val_result = 10000000
    disc_training_scheduler = train_utils.build_scheduler(disc_opt, params)
    training_scheduler = train_utils.build_scheduler(opt, params)
    end = start = 0
    misc.seed(1)
    CEloss = nn.CrossEntropyLoss()
    for epoch in range(start_epoch, num_epochs + 1):
        #print("EPOCH", epoch, (end-start))
        disc.train()
        disc.train_percent = epoch / num_epochs
        model.train()
        model.train_percent = epoch / num_epochs
        start = time.time()
        for batch_ind, batch in enumerate(train_data_loader):
            inputs = batch['inputs']
            if gpu:
                inputs = inputs.cuda(non_blocking=True)
            loss, loss_nll, loss_kl, logits, all_Preds = model.calculate_loss(inputs, is_train=True, return_logits=True, disc=disc)
            # --- discriminator update --------------------------------------
            # Consecutive prediction pairs (t, t+1), detached so the
            # discriminator loss does not backprop into the model.
            x1_x2_pairs = torch.cat([all_Preds[:, :-1, :, :], all_Preds[:, 1:, :, :]], dim=-1).detach().clone().cuda()
            discrim_pred = disc(x1_x2_pairs)
            discrim_prob = nn.functional.softmax(discrim_pred, dim=-1)
            # Targets: encoder's hard (argmax) edge assignments per step.
            disc_logits = logits[:, :-1, :, :].argmax(dim=-1)
            if batch_ind == 0 and (epoch % 99) == 0:
                # Periodic debug dump: discriminator vs encoder decisions.
                print("disc_prob/encoder_logits", discrim_prob.shape, disc_logits.shape)
                for i in range(logits.shape[1] - 1):
                    print("prob/trgt", discrim_prob[1, i, :].cpu().detach().argmax(dim=-1).numpy(), disc_logits[1, i, :].cpu().detach().numpy())
            discrim_prob = discrim_prob.view(-1, logits.shape[-1])
            disc_logits = disc_logits.flatten().detach().clone().long()
            # Take all rows with non-zero targets plus an equal number of
            # randomly drawn rows.
            valid_idx = disc_logits.nonzero().view(-1)
            valid_idx0 = torch.randint(discrim_prob.shape[0], (valid_idx.shape[0], ))
            final_disc_prob = torch.cat([discrim_prob[valid_idx], discrim_prob[valid_idx0]], dim=0)
            final_disc_logits = torch.cat([disc_logits[valid_idx], disc_logits[valid_idx0]], dim=0)
            # NOTE(review): nn.CrossEntropyLoss expects raw (unnormalized)
            # scores, but softmaxed probabilities are passed here — confirm
            # this double-softmax is intended.
            disc_loss = CEloss(final_disc_prob, final_disc_logits)
            disc_loss.backward()
            disc_opt.step()
            disc_opt.zero_grad()
            # --- model update ---------------------------------------------
            loss.backward()
            if verbose:
                print("\tBATCH %d OF %d: %f, %f, %f" % (batch_ind + 1, len(train_data_loader), loss.item(), loss_nll.mean().item(), loss_kl.mean().item()))
            # Apply accumulated gradients every `accumulate_steps` batches
            # (-1 means every batch).
            if accumulate_steps == -1 or (batch_ind + 1) % accumulate_steps == 0:
                if verbose and accumulate_steps > 0:
                    print("\tUPDATING WEIGHTS")
                if clip_grad is not None:
                    nn.utils.clip_grad_value_(model.parameters(), clip_grad)
                elif clip_grad_norm is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                opt.step()
                opt.zero_grad()
                # End the epoch early if too few batches remain to fill
                # another accumulation window.
                if accumulate_steps > 0 and accumulate_steps > len(train_data_loader) - batch_ind - 1:
                    break
        # Per-epoch LR scheduling for both optimizers.
        if training_scheduler is not None:
            training_scheduler.step()
            disc_training_scheduler.step()
        if train_writer is not None:
            # Logs the losses from the LAST training batch of the epoch.
            train_writer.add_scalar('loss', loss.item(), global_step=epoch)
            if normalize_nll:
                train_writer.add_scalar('NLL', loss_nll.mean().item(), global_step=epoch)
            else:
                train_writer.add_scalar('NLL', loss_nll.mean().item() / (inputs.size(1) * inputs.size(2)), global_step=epoch)
            train_writer.add_scalar("KL Divergence", loss_kl.mean().item(), global_step=epoch)
        # --- validation (main model only; disc is not evaluated) ----------
        model.eval()
        opt.zero_grad()
        disc_opt.zero_grad()
        total_nll = 0
        total_kl = 0
        if verbose:
            print("COMPUTING VAL LOSSES")
        with torch.no_grad():
            for batch_ind, batch in enumerate(val_data_loader):
                inputs = batch['inputs']
                if gpu:
                    inputs = inputs.cuda(non_blocking=True)
                loss, loss_nll, loss_kl, logits, _ = model.calculate_loss(inputs, is_train=False, teacher_forcing=val_teacher_forcing, return_logits=True)
                total_kl += loss_kl.sum().item()
                total_nll += loss_nll.sum().item()
                if verbose:
                    print("\tVAL BATCH %d of %d: %f, %f" % (batch_ind + 1, len(val_data_loader), loss_nll.mean(), loss_kl.mean()))
        # Average over the whole validation set.
        total_kl /= len(val_data)
        total_nll /= len(val_data)
        total_loss = model.kl_coef * total_kl + total_nll #TODO: this is a thing you fixed
        if val_writer is not None:
            val_writer.add_scalar('loss', total_loss, global_step=epoch)
            val_writer.add_scalar("NLL", total_nll, global_step=epoch)
            val_writer.add_scalar("KL Divergence", total_kl, global_step=epoch)
        # Model selection criterion: NLL alone, or the full combined loss.
        if tune_on_nll:
            tuning_loss = total_nll
        else:
            tuning_loss = total_loss
        if tuning_loss < best_val_result:
            best_val_epoch = epoch
            best_val_result = tuning_loss
            #print("BEST VAL RESULT. SAVING MODEL...")
            model.save(best_path)
        # Always write a rolling checkpoint so training can resume.
        model.save(checkpoint_dir)
        torch.save(
            {
                'epoch': epoch + 1,
                'optimizer': opt.state_dict(),
                'best_val_result': best_val_result,
                'best_val_epoch': best_val_epoch,
            }, training_path)
        #print("EPOCH %d EVAL: "%epoch)
        #print("\tCURRENT VAL LOSS: %f"%tuning_loss)
        #print("\tBEST VAL LOSS: %f"%best_val_result)
        #print("\tBEST VAL EPOCH: %d"%best_val_epoch)
        end = time.time()