def main(args):
    """Train and/or evaluate a VanillaDMM model on the diffusion dataset.

    Depending on ``args.validonly`` this either runs the SVI training loop
    (with a one-off evaluation at the final epoch and periodic checkpointing)
    or runs a standalone validation pass that also pickles a few sample
    predictions to disk.

    Args:
        args: parsed CLI namespace. Fields read here: ``runname``,
            ``trialnumber``, ``datapath`` (a ``pathlib.Path``), ``endepoch``,
            ``startepoch``, ``validonly``, ``modelsavepath``,
            ``validpredlength``.
    """
    # Init tensorboard
    writer = SummaryWriter('./runs/' + args.runname + str(args.trialnumber))
    model_name = 'VanillaDMM'

    # Set evaluation log file
    evaluation_logpath = './logs/{}/evaluation_result.log'.format(
        model_name.lower())
    log_evaluation(evaluation_logpath,
                   'Evaluation Trial - {}\n'.format(args.trialnumber))

    # Constants
    time_length = 30
    input_length_for_pred = 20
    pred_length = time_length - input_length_for_pred
    # NOTE(review): the original code immediately overwrote pred_length with 0;
    # the value is never read afterwards, so the dead reassignment was dropped.
    train_batch_size = 16
    valid_batch_size = 1

    # Model hyper-parameters
    input_channels = 1
    z_channels = 50
    emission_channels = [64, 32]
    transition_channels = 64
    encoder_channels = [32, 64]
    rnn_input_dim = 256
    rnn_channels = 128
    kernel_size = 3

    # Device checking
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Make dataset
    logging.info("Generate data")
    train_datapath = args.datapath / 'train'
    valid_datapath = args.datapath / 'valid'
    train_dataset = DiffusionDataset(train_datapath)
    valid_dataset = DiffusionDataset(valid_datapath)

    # Create data loaders from pickle data
    logging.info("Generate data loaders")
    train_dataloader = DataLoader(
        train_dataset, batch_size=train_batch_size, shuffle=True,
        num_workers=8)
    valid_dataloader = DataLoader(
        valid_dataset, batch_size=valid_batch_size, num_workers=4)

    # Training parameters
    width = 100
    height = 100
    input_dim = width * height

    # Create model
    logging.warning("Generate model")
    logging.warning(input_dim)
    pred_input_dim = 10
    dmm = DMM(input_channels=input_channels,
              z_channels=z_channels,
              emission_channels=emission_channels,
              transition_channels=transition_channels,
              encoder_channels=encoder_channels,
              rnn_input_dim=rnn_input_dim,
              rnn_channels=rnn_channels,
              kernel_size=kernel_size,
              height=height,
              width=width,
              pred_input_dim=pred_input_dim,
              num_layers=1,
              rnn_dropout_rate=0.0,
              num_iafs=0,
              iaf_dim=50,
              use_cuda=use_cuda)

    # Initialize optimizer and the SVI inference object
    logging.info("Initialize model")
    epochs = args.endepoch
    learning_rate = 0.0001
    beta1 = 0.9
    beta2 = 0.999
    clip_norm = 10.0
    lr_decay = 1.0
    weight_decay = 0
    adam_params = {"lr": learning_rate,
                   "betas": (beta1, beta2),
                   "clip_norm": clip_norm,
                   "lrd": lr_decay,
                   "weight_decay": weight_decay}
    adam = ClippedAdam(adam_params)
    elbo = Trace_ELBO()
    svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # Directory used for model / optimizer checkpoints
    save_model = Path('./checkpoints/' + model_name)

    def save_checkpoint(epoch):
        # Saves the model and optimizer states to disk for the given epoch.
        save_dir = save_model / '{}.model'.format(epoch)
        save_opt_dir = save_model / '{}.opt'.format(epoch)
        logging.info("saving model to %s..." % save_dir)
        torch.save(dmm.state_dict(), save_dir)
        logging.info("saving optimizer states to %s..." % save_opt_dir)
        adam.save(save_opt_dir)
        logging.info("done saving model and optimizer checkpoints to disk.")

    # Starting epoch
    start_epoch = args.startepoch

    # Loads the model (and optionally optimizer) states from disk when
    # resuming from a non-zero epoch.
    if start_epoch != 0:
        load_opt = './checkpoints/' + model_name + \
            '/e{}-i188-opt-tn{}.opt'.format(start_epoch - 1, args.trialnumber)
        load_model = './checkpoints/' + model_name + \
            '/e{}-i188-tn{}.pt'.format(start_epoch - 1, args.trialnumber)

        def load_checkpoint():
            # Optimizer-state restore is intentionally disabled; only the
            # model weights are reloaded.
            logging.info("loading model from %s..." % load_model)
            dmm.load_state_dict(torch.load(load_model, map_location=device))

        if load_model != '':
            logging.info('Load checkpoint')
            load_checkpoint()

    # Validation only?
    validation_only = args.validonly

    # Train the model
    if not validation_only:
        logging.info("Training model")
        annealing_epochs = 1000
        minimum_annealing_factor = 0.2
        N_train_size = 3000
        # ceil(N_train_size / train_batch_size) without importing math
        N_mini_batches = int(N_train_size / train_batch_size +
                             int(N_train_size % train_batch_size > 0))

        for epoch in tqdm(range(start_epoch, epochs), desc='Epoch',
                          leave=True):
            r_loss_train = 0
            dmm.train(True)
            idx = 0
            mov_avg_loss = 0
            mov_data_len = 0
            for which_mini_batch, data in enumerate(
                    tqdm(train_dataloader, desc='Train', leave=True)):
                if annealing_epochs > 0 and epoch < annealing_epochs:
                    # Compute the KL annealing factor appropriate for the
                    # current mini-batch in the current epoch: ramps linearly
                    # from min_af up to 1.0 over `annealing_epochs` epochs.
                    min_af = minimum_annealing_factor
                    annealing_factor = min_af + (1.0 - min_af) * \
                        (float(which_mini_batch + epoch * N_mini_batches + 1) /
                         float(annealing_epochs * N_mini_batches))
                else:
                    # By default the KL annealing factor is unity
                    annealing_factor = 1.0

                data['observation'] = normalize(
                    data['observation'].unsqueeze(2).to(device))
                batch_size, length, _, w, h = data['observation'].shape
                data_reversed = reverse_sequences(data['observation'])
                # BUG FIX: was a hard-coded .cuda(), which crashed on
                # CPU-only hosts; use the device selected above instead.
                data_mask = torch.ones(
                    batch_size, length, input_channels, w, h).to(device)

                loss = svi.step(data['observation'], data_reversed, data_mask,
                                annealing_factor)

                # Running losses
                mov_avg_loss += loss
                mov_data_len += batch_size
                r_loss_train += loss
                idx += 1

            # Average losses
            train_loss_avg = r_loss_train / (len(train_dataset) * time_length)
            writer.add_scalar('Loss/train', train_loss_avg, epoch)
            logging.info("Epoch: %d, Training loss: %1.5f",
                         epoch, train_loss_avg)

            # One-off evaluation at the final training epoch
            if epoch == epochs - 1:
                for temp_pred_length in [20]:
                    r_loss_valid = 0
                    r_loss_loc_valid = 0
                    r_loss_scale_valid = 0
                    dmm.train(False)
                    val_pred_length = temp_pred_length
                    val_pred_input_length = 10
                    with torch.no_grad():
                        for i, data in enumerate(
                                tqdm(valid_dataloader, desc='Eval',
                                     leave=True)):
                            data['observation'] = normalize(
                                data['observation'].unsqueeze(2).to(device))
                            batch_size, length, _, w, h = \
                                data['observation'].shape
                            data_reversed = reverse_sequences(
                                data['observation'])
                            # BUG FIX: .cuda() -> .to(device) (see above).
                            data_mask = torch.ones(
                                batch_size, length, input_channels, w,
                                h).to(device)
                            pred_tensor_mask = torch.ones(
                                batch_size, input_length_for_pred,
                                input_channels, w, h).to(device)

                            val_nll = svi.evaluate_loss(
                                data['observation'], data_reversed, data_mask)
                            # NOTE(review): dead computations removed here
                            # (reversed prediction tensor, denormalized
                            # ground truth / prediction) — their results were
                            # never read in this branch.
                            preds, _, loss_loc, loss_scale = \
                                do_prediction_rep_inference(
                                    dmm, pred_tensor_mask, val_pred_length,
                                    val_pred_input_length,
                                    data['observation'])

                            # Running losses
                            r_loss_valid += val_nll
                            r_loss_loc_valid += loss_loc
                            r_loss_scale_valid += loss_scale

                    # Average losses
                    valid_loss_avg = r_loss_valid / \
                        (len(valid_dataset) * time_length)
                    valid_loss_loc_avg = r_loss_loc_valid / \
                        (len(valid_dataset) * val_pred_length * width * height)
                    valid_loss_scale_avg = r_loss_scale_valid / \
                        (len(valid_dataset) * val_pred_length * width * height)
                    writer.add_scalar('Loss/test', valid_loss_avg, epoch)
                    writer.add_scalar(
                        'Loss/test_obs', valid_loss_loc_avg, epoch)
                    writer.add_scalar('Loss/test_scale',
                                      valid_loss_scale_avg, epoch)
                    logging.info("Validation loss: %1.5f", valid_loss_avg)
                    logging.info("Validation obs loss: %1.5f",
                                 valid_loss_loc_avg)
                    logging.info("Validation scale loss: %1.5f",
                                 valid_loss_scale_avg)
                    log_evaluation(
                        evaluation_logpath,
                        "Validation obs loss for {}s pred {}: {}\n".format(
                            val_pred_length, args.trialnumber,
                            valid_loss_loc_avg))
                    log_evaluation(
                        evaluation_logpath,
                        "Validation scale loss for {}s pred {}: {}\n".format(
                            val_pred_length, args.trialnumber,
                            valid_loss_scale_avg))

            # Save model every 50 epochs and at the final epoch
            if epoch % 50 == 0 or epoch == epochs - 1:
                torch.save(dmm.state_dict(),
                           args.modelsavepath / model_name /
                           'e{}-i{}-tn{}.pt'.format(epoch, idx,
                                                    args.trialnumber))
                adam.save(args.modelsavepath / model_name /
                          'e{}-i{}-opt-tn{}.opt'.format(epoch, idx,
                                                        args.trialnumber))

    # Last validation after training
    test_samples_indices = range(100)
    total_n = 0
    if validation_only:
        r_loss_loc_valid = 0
        r_loss_scale_valid = 0
        r_loss_latent_valid = 0
        dmm.train(False)
        val_pred_length = args.validpredlength
        val_pred_input_length = 10
        with torch.no_grad():
            for i in tqdm(test_samples_indices, desc='Valid', leave=True):
                # Data processing: skip samples containing NaNs so they do
                # not poison the averaged losses.
                data = valid_dataset[i]
                if torch.isnan(torch.sum(data['observation'])):
                    print("Skip {}".format(i))
                    continue
                else:
                    total_n += 1
                data['observation'] = normalize(
                    data['observation'].unsqueeze(0).unsqueeze(2).to(device))
                batch_size, length, _, w, h = data['observation'].shape
                # NOTE(review): dead locals removed here (reversed sequence
                # and full-length mask) — nothing in this branch consumed
                # them.

                # Prediction
                pred_tensor_mask = torch.ones(
                    batch_size, input_length_for_pred, input_channels, w,
                    h).to(device)
                preds, _, loss_loc, loss_scale = do_prediction_rep_inference(
                    dmm, pred_tensor_mask, val_pred_length,
                    val_pred_input_length, data['observation'])
                ground_truth = denormalize(
                    data['observation'].squeeze().cpu().detach()
                )
                # Prediction stitched onto its conditioning prefix, in the
                # original (denormalized) scale.
                pred_with_input = denormalize(
                    torch.cat(
                        [data['observation'][:, :-val_pred_length, :, :, :]
                         .squeeze(),
                         preds.squeeze()],
                        dim=0
                    ).cpu().detach()
                )

                # Save the first few samples for qualitative inspection
                if i < 5:
                    save_dir_samples = Path('./samples/more_variance_long')
                    with open(save_dir_samples /
                              '{}-gt-test.pkl'.format(i), 'wb') as fout:
                        pickle.dump(ground_truth, fout)
                    with open(save_dir_samples /
                              '{}-vanilladmm-pred-test.pkl'.format(i),
                              'wb') as fout:
                        pickle.dump(pred_with_input, fout)

                # Running losses (latent loss = squared error against the
                # held-out latent trajectory)
                r_loss_loc_valid += loss_loc
                r_loss_scale_valid += loss_scale
                r_loss_latent_valid += np.sum(
                    (preds.squeeze().detach().cpu().numpy() -
                     data['latent'][time_length - val_pred_length:, :, :]
                     .detach().cpu().numpy()) ** 2)

        # Average losses over the samples that were actually evaluated.
        # NOTE(review): the dead reassignment of test_samples_indices was
        # removed; total_n can be 0 if every sample contained NaNs, in which
        # case the divisions below would raise ZeroDivisionError.
        print(total_n)
        valid_loss_loc_avg = r_loss_loc_valid / \
            (total_n * val_pred_length * width * height)
        valid_loss_scale_avg = r_loss_scale_valid / \
            (total_n * val_pred_length * width * height)
        valid_loss_latent_avg = r_loss_latent_valid / \
            (total_n * val_pred_length * width * height)
        logging.info("Validation obs loss for %ds pred VanillaDMM: %f",
                     val_pred_length, valid_loss_loc_avg)
        logging.info("Validation latent loss: %f", valid_loss_latent_avg)
        with open('VanillaDMMResult.log', 'a+') as fout:
            validation_log = 'Pred {}s VanillaDMM: {}\n'.format(
                val_pred_length, valid_loss_loc_avg)
            fout.write(validation_log)
class Trainer(object):
    """Trainer class used to instantiate, train and validate a DMM network
    on the PTB character-level language-modeling data."""

    def __init__(self, args):
        """Gather hyper-parameters from the parsed CLI namespace and set up
        logging.

        Args:
            args: namespace with the fields copied below (seed, device,
                optimizer, annealing, dropout and checkpointing settings).
        """
        # argument gathering
        self.rand_seed = args.rand_seed
        self.dev_num = args.dev_num
        self.cuda = args.cuda
        self.n_epoch = args.n_epoch
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.wd = args.wd
        self.cn = args.cn
        self.lr_decay = args.lr_decay
        self.kl_ae = args.kl_ae          # number of KL-annealing epochs
        self.maf = args.maf              # minimum annealing factor
        self.dropout = args.dropout
        self.ckpt_f = args.ckpt_f        # checkpoint frequency (epochs)
        self.load_opt = args.load_opt
        self.load_model = args.load_model
        self.save_opt = args.save_opt
        self.save_model = args.save_model
        self.maxlen = args.maxlen

        # setup logging
        self.log = utils.get_logger(args.log)
        self.log(args)

    def _validate(self, val_iter):
        """Freeze training and compute the average validation NLL over
        ``val_iter``. Returns the per-datapoint loss."""
        # freeze training
        self.dmm.rnn.eval()
        val_nll = 0
        for ii, batch in enumerate(iter(val_iter)):
            batch_data = Variable(batch.text[0].to(self.device))
            seqlens = Variable(batch.text[1].to(self.device))
            # transpose to [B, seqlen] shape
            batch_data = torch.t(batch_data)
            # compute one hot character embedding
            batch_onehot = nn.functional.one_hot(batch_data,
                                                 self.vocab_size).float()
            # flip sequence for rnn
            batch_reversed = utils.reverse_seq(batch_onehot, seqlens)
            batch_reversed = nn.utils.rnn.pack_padded_sequence(
                batch_reversed, seqlens, batch_first=True)
            # compute temporal mask
            batch_mask = utils.generate_batch_mask(batch_onehot,
                                                   seqlens).cuda()
            # perform evaluation
            val_nll += self.svi.evaluate_loss(batch_onehot, batch_reversed,
                                              batch_mask, seqlens)
        # resume training
        self.dmm.rnn.train()
        loss = val_nll / self.N_val_data
        return loss

    def _train_batch(self, train_iter, epoch):
        """Run one full pass over ``train_iter`` (a single epoch) and return
        the accumulated, length-normalized training loss."""
        batch_loss = 0
        epoch_loss = 0
        for ii, batch in enumerate(iter(train_iter)):
            batch_data = Variable(batch.text[0].to(self.device))
            seqlens = Variable(batch.text[1].to(self.device))
            # transpose to [B, seqlen] shape
            batch_data = torch.t(batch_data)
            # compute one hot character embedding
            batch_onehot = nn.functional.one_hot(batch_data,
                                                 self.vocab_size).float()
            # flip sequence for rnn
            batch_reversed = utils.reverse_seq(batch_onehot, seqlens)
            batch_reversed = nn.utils.rnn.pack_padded_sequence(
                batch_reversed, seqlens, batch_first=True)
            # compute temporal mask
            batch_mask = utils.generate_batch_mask(batch_onehot,
                                                   seqlens).cuda()
            # compute kl-div annealing factor: linear ramp from maf to 1.0
            # over the first kl_ae epochs
            if self.kl_ae > 0 and epoch < self.kl_ae:
                min_af = self.maf
                kl_anneal = min_af + (1 - min_af) * (
                    float(ii + epoch * self.N_batches + 1) /
                    float(self.kl_ae * self.N_batches))
            else:
                # default kl-div annealing factor is unity
                kl_anneal = 1.0
            # take gradient step
            batch_loss = self.svi.step(batch_onehot, batch_reversed,
                                       batch_mask, seqlens, kl_anneal)
            # normalize by the total number of characters in the batch
            batch_loss = batch_loss / (torch.sum(seqlens).float())
            print("loss at iteration {0} is {1}".format(ii, batch_loss))
            epoch_loss = epoch_loss + batch_loss
        return epoch_loss

    def train(self):
        """Train the network on the PTB training set, validating every 10
        epochs and checkpointing every ``ckpt_f`` epochs."""
        self.device = torch.device("cuda")
        np.random.seed(self.rand_seed)
        torch.manual_seed(self.rand_seed)

        train, val, test, vocab = datahandler.load_data(
            "./data/ptb", self.maxlen)
        self.vocab_size = len(vocab)

        # make iterable dataset object
        train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(
            (train, val, test),
            batch_sizes=[self.batch_size, 1, 1],
            device=self.device,
            repeat=False,
            sort_key=lambda x: len(x.text),
            sort_within_batch=True,
        )

        # NOTE(review): these three statements were duplicated verbatim in
        # the original; the redundant copy was removed.
        self.N_train_data = len(train)
        self.N_val_data = len(val)
        # ceil(N_train_data / batch_size) without importing math
        self.N_batches = int(self.N_train_data / self.batch_size +
                             int(self.N_train_data % self.batch_size > 0))
        self.log("N_train_data: %d N_mini_batches: %d" %
                 (self.N_train_data, self.N_batches))

        # instantiate the dmm
        self.dmm = DMM(input_dim=self.vocab_size, dropout=self.dropout)

        # setup optimizer
        opt_params = {
            "lr": self.lr,
            "betas": (self.beta1, self.beta2),
            "clip_norm": self.cn,
            "lrd": self.lr_decay,
            "weight_decay": self.wd,
        }
        self.adam = ClippedAdam(opt_params)

        # set up inference algorithm
        self.elbo = Trace_ELBO()
        self.svi = SVI(self.dmm.model, self.dmm.guide, self.adam,
                       loss=self.elbo)

        val_f = 10  # validation frequency (epochs)
        print("training dmm")
        times = [time.time()]
        for epoch in range(self.n_epoch):
            if self.ckpt_f > 0 and epoch > 0 and epoch % self.ckpt_f == 0:
                self.save_ckpt()

            # train and report metrics
            train_nll = self._train_batch(train_iter, epoch)
            times.append(time.time())
            t_elps = times[-1] - times[-2]
            self.log("epoch %04d -> train nll: %.4f \t t_elps=%.3f sec" %
                     (epoch, train_nll, t_elps))

            if epoch % val_f == 0:
                # BUG FIX: the validation loss was computed but silently
                # discarded; report it like the training loss.
                val_nll = self._validate(val_iter)
                self.log("epoch %04d -> val nll: %.4f" % (epoch, val_nll))

    def save_ckpt(self):
        """Save the state of the network and optimizer for later."""
        self.log("saving model to %s" % self.save_model)
        torch.save(self.dmm.state_dict(), self.save_model)
        self.log("saving optimizer states to %s" % self.save_opt)
        self.adam.save(self.save_opt)

    def load_ckpt(self):
        """Load a saved model/optimizer checkpoint from the paths given at
        construction time."""
        # BUG FIX: the original referenced the undefined global `args` here,
        # which raised NameError whenever load_ckpt() was called; use the
        # paths stored on self instead.
        assert exists(self.load_opt) and exists(self.load_model), \
            "--load-model and/or --load-opt misspecified"
        self.log("loading model from %s..." % self.load_model)
        self.dmm.load_state_dict(torch.load(self.load_model))
        self.log("loading optimizer states from %s..." % self.load_opt)
        self.adam.load(self.load_opt)