Example #1
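
# Assumed imports for this snippet (not shown in the original excerpt).
# The project-specific helpers (DMM, DiffusionDataset, normalize, denormalize,
# reverse_sequences, do_prediction_rep_inference, log_evaluation) are assumed
# to come from local modules of the same project.
import logging
import pickle
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam
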
def main(args):
    # Init tensorboard
    writer = SummaryWriter('./runs/' + args.runname + str(args.trialnumber))
    model_name = 'VanillaDMM'

    # Set evaluation log file
    evaluation_logpath = './logs/{}/evaluation_result.log'.format(
        model_name.lower())
    log_evaluation(evaluation_logpath,
                   'Evaluation Trial - {}\n'.format(args.trialnumber))

    # Constants
    time_length = 30
    input_length_for_pred = 20
    pred_length = time_length - input_length_for_pred
    train_batch_size = 16
    valid_batch_size = 1

    # For model
    input_channels = 1
    z_channels = 50
    emission_channels = [64, 32]
    transition_channels = 64
    encoder_channels = [32, 64]
    rnn_input_dim = 256
    rnn_channels = 128
    kernel_size = 3
    pred_length = 0  # note: overrides the value computed from time_length above

    # Device checking
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Make dataset
    logging.info("Generate data")
    train_datapath = args.datapath / 'train'
    valid_datapath = args.datapath / 'valid'
    train_dataset = DiffusionDataset(train_datapath)
    valid_dataset = DiffusionDataset(valid_datapath)

    # Create data loaders from pickle data
    logging.info("Generate data loaders")
    train_dataloader = DataLoader(
        train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=8)
    valid_dataloader = DataLoader(
        valid_dataset, batch_size=valid_batch_size, num_workers=4)

    # Training parameters
    width = 100
    height = 100
    input_dim = width * height

    # Create model
    logging.warning("Generate model")
    logging.warning(input_dim)
    pred_input_dim = 10
    dmm = DMM(input_channels=input_channels, z_channels=z_channels,
              emission_channels=emission_channels,
              transition_channels=transition_channels,
              encoder_channels=encoder_channels,
              rnn_input_dim=rnn_input_dim, rnn_channels=rnn_channels,
              kernel_size=kernel_size, height=height, width=width,
              pred_input_dim=pred_input_dim, num_layers=1,
              rnn_dropout_rate=0.0, num_iafs=0, iaf_dim=50, use_cuda=use_cuda)

    # Initialize model
    logging.info("Initialize model")
    epochs = args.endepoch
    learning_rate = 0.0001
    beta1 = 0.9
    beta2 = 0.999
    clip_norm = 10.0
    lr_decay = 1.0
    weight_decay = 0
    adam_params = {"lr": learning_rate, "betas": (beta1, beta2),
                   "clip_norm": clip_norm, "lrd": lr_decay,
                   "weight_decay": weight_decay}
    adam = ClippedAdam(adam_params)
    elbo = Trace_ELBO()
    svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)
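    # svi.step() takes one gradient step and returns the mini-batch ELBO loss;
    # svi.evaluate_loss() computes the same loss without updating parameters.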

    # Checkpoint directory; save_checkpoint below saves the model and optimizer states to disk
    save_model = Path('./checkpoints/' + model_name)

    def save_checkpoint(epoch):
        save_dir = save_model / '{}.model'.format(epoch)
        save_opt_dir = save_model / '{}.opt'.format(epoch)
        logging.info("saving model to %s..." % save_dir)
        torch.save(dmm.state_dict(), save_dir)
        logging.info("saving optimizer states to %s..." % save_opt_dir)
        adam.save(save_opt_dir)
        logging.info("done saving model and optimizer checkpoints to disk.")

    # Starting epoch
    start_epoch = args.startepoch

    # loads the model and optimizer states from disk
    if start_epoch != 0:
        load_opt = './checkpoints/' + model_name + \
            '/e{}-i188-opt-tn{}.opt'.format(start_epoch - 1, args.trialnumber)
        load_model = './checkpoints/' + model_name + \
            '/e{}-i188-tn{}.pt'.format(start_epoch - 1, args.trialnumber)

        def load_checkpoint():
            # assert exists(load_opt) and exists(load_model), \
            #     "--load-model and/or --load-opt misspecified"
            logging.info("loading model from %s..." % load_model)
            dmm.load_state_dict(torch.load(load_model, map_location=device))
            # logging.info("loading optimizer states from %s..." % load_opt)
            # adam.load(load_opt)
            # logging.info("done loading model and optimizer states.")

        if load_model != '':
            logging.info('Load checkpoint')
            load_checkpoint()

    # Validation only?
    validation_only = args.validonly

    # Train the model
    if not validation_only:
        logging.info("Training model")
        annealing_epochs = 1000
        minimum_annealing_factor = 0.2
        N_train_size = 3000
        N_mini_batches = int(N_train_size / train_batch_size +
                             int(N_train_size % train_batch_size > 0))
        for epoch in tqdm(range(start_epoch, epochs), desc='Epoch', leave=True):
            r_loss_train = 0
            dmm.train(True)
            idx = 0
            mov_avg_loss = 0
            mov_data_len = 0
            for which_mini_batch, data in enumerate(tqdm(train_dataloader, desc='Train', leave=True)):
                if annealing_epochs > 0 and epoch < annealing_epochs:
                    # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
                    min_af = minimum_annealing_factor
                    annealing_factor = min_af + (1.0 - min_af) * \
                        (float(which_mini_batch + epoch * N_mini_batches + 1) /
                         float(annealing_epochs * N_mini_batches))
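                    # With the constants above (min_af = 0.2, annealing_epochs
                    # = 1000, N_mini_batches = ceil(3000 / 16) = 188), the
                    # factor ramps linearly from about 0.2 on the first
                    # mini-batch up to 1.0 at the end of annealing.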
                else:
                    # by default the KL annealing factor is unity
                    annealing_factor = 1.0

                data['observation'] = normalize(
                    data['observation'].unsqueeze(2).to(device))
                batch_size, length, _, w, h = data['observation'].shape
                data_reversed = reverse_sequences(data['observation'])
                data_mask = torch.ones(
                    batch_size, length, input_channels, w, h).to(device)

                loss = svi.step(data['observation'],
                                data_reversed, data_mask, annealing_factor)

                # Running losses
                mov_avg_loss += loss
                mov_data_len += batch_size

                r_loss_train += loss
                idx += 1

            # Average losses
            train_loss_avg = r_loss_train / (len(train_dataset) * time_length)
            writer.add_scalar('Loss/train', train_loss_avg, epoch)
            logging.info("Epoch: %d, Training loss: %1.5f",
                         epoch, train_loss_avg)

            # Periodic evaluation (here only on the final epoch)
            if epoch == epochs - 1:
                for temp_pred_length in [20]:
                    r_loss_valid = 0
                    r_loss_loc_valid = 0
                    r_loss_scale_valid = 0
                    r_loss_latent_valid = 0
                    dmm.train(False)
                    val_pred_length = temp_pred_length
                    val_pred_input_length = 10
                    with torch.no_grad():
                        for i, data in enumerate(tqdm(valid_dataloader, desc='Eval', leave=True)):
                            data['observation'] = normalize(
                                data['observation'].unsqueeze(2).to(device))
                            batch_size, length, _, w, h = data['observation'].shape
                            data_reversed = reverse_sequences(
                                data['observation'])
                            data_mask = torch.ones(
                                batch_size, length, input_channels, w, h).to(device)

                            pred_tensor = data['observation'][:, :input_length_for_pred, :, :, :]
                            pred_tensor_reversed = reverse_sequences(
                                pred_tensor)
                            pred_tensor_mask = torch.ones(
                                batch_size, input_length_for_pred, input_channels, w, h).to(device)

                            ground_truth = data['observation'][:, input_length_for_pred:, :, :, :]

                            val_nll = svi.evaluate_loss(
                                data['observation'], data_reversed, data_mask)

                            preds, _, loss_loc, loss_scale = do_prediction_rep_inference(
                                dmm, pred_tensor_mask, val_pred_length, val_pred_input_length, data['observation'])

                            ground_truth = denormalize(
                                data['observation'].squeeze().cpu().detach()
                            )
                            pred_with_input = denormalize(
                                torch.cat(
                                    [data['observation'][:, :-val_pred_length, :, :, :].squeeze(),
                                     preds.squeeze()], dim=0
                                ).cpu().detach()
                            )

                            # Running losses
                            r_loss_valid += val_nll
                            r_loss_loc_valid += loss_loc
                            r_loss_scale_valid += loss_scale

                    # Average losses
                    valid_loss_avg = r_loss_valid / \
                        (len(valid_dataset) * time_length)
                    valid_loss_loc_avg = r_loss_loc_valid / \
                        (len(valid_dataset) * val_pred_length * width * height)
                    valid_loss_scale_avg = r_loss_scale_valid / \
                        (len(valid_dataset) * val_pred_length * width * height)
                    writer.add_scalar('Loss/test', valid_loss_avg, epoch)
                    writer.add_scalar(
                        'Loss/test_obs', valid_loss_loc_avg, epoch)
                    writer.add_scalar('Loss/test_scale',
                                      valid_loss_scale_avg, epoch)
                    logging.info("Validation loss: %1.5f", valid_loss_avg)
                    logging.info("Validation obs loss: %1.5f",
                                 valid_loss_loc_avg)
                    logging.info("Validation scale loss: %1.5f",
                                 valid_loss_scale_avg)
                    log_evaluation(evaluation_logpath, "Validation obs loss for {}s pred {}: {}\n".format(
                        val_pred_length, args.trialnumber, valid_loss_loc_avg))
                    log_evaluation(evaluation_logpath, "Validation scale loss for {}s pred {}: {}\n".format(
                        val_pred_length, args.trialnumber, valid_loss_scale_avg))

            # Save model
            if epoch % 50 == 0 or epoch == epochs - 1:
                torch.save(dmm.state_dict(), args.modelsavepath / model_name /
                           'e{}-i{}-tn{}.pt'.format(epoch, idx, args.trialnumber))
                adam.save(args.modelsavepath / model_name /
                          'e{}-i{}-opt-tn{}.opt'.format(epoch, idx, args.trialnumber))

    # Last validation after training
    test_samples_indices = range(100)
    total_n = 0
    if validation_only:
        r_loss_loc_valid = 0
        r_loss_scale_valid = 0
        r_loss_latent_valid = 0
        dmm.train(False)
        val_pred_length = args.validpredlength
        val_pred_input_length = 10
        with torch.no_grad():
            for i in tqdm(test_samples_indices, desc='Valid', leave=True):
                # Data processing
                data = valid_dataset[i]
                if torch.isnan(torch.sum(data['observation'])):
                    print("Skip {}".format(i))
                    continue
                else:
                    total_n += 1
                data['observation'] = normalize(
                    data['observation'].unsqueeze(0).unsqueeze(2).to(device))
                batch_size, length, _, w, h = data['observation'].shape
                data_reversed = reverse_sequences(data['observation'])
                data_mask = torch.ones(
                    batch_size, length, input_channels, w, h).to(device)

                # Prediction
                pred_tensor_mask = torch.ones(
                    batch_size, input_length_for_pred, input_channels, w, h).to(device)
                preds, _, loss_loc, loss_scale = do_prediction_rep_inference(
                    dmm, pred_tensor_mask, val_pred_length, val_pred_input_length, data['observation'])

                ground_truth = denormalize(
                    data['observation'].squeeze().cpu().detach()
                )
                pred_with_input = denormalize(
                    torch.cat(
                        [data['observation'][:, :-val_pred_length, :, :, :].squeeze(),
                         preds.squeeze()], dim=0
                    ).cpu().detach()
                )

                # Save samples
                if i < 5:
                    save_dir_samples = Path('./samples/more_variance_long')
                    with open(save_dir_samples / '{}-gt-test.pkl'.format(i), 'wb') as fout:
                        pickle.dump(ground_truth, fout)
                    with open(save_dir_samples / '{}-vanilladmm-pred-test.pkl'.format(i), 'wb') as fout:
                        pickle.dump(pred_with_input, fout)

                # Running losses
                r_loss_loc_valid += loss_loc
                r_loss_scale_valid += loss_scale
                r_loss_latent_valid += np.sum(
                    (preds.squeeze().detach().cpu().numpy()
                     - data['latent'][time_length - val_pred_length:, :, :].detach().cpu().numpy()) ** 2)

        # Average losses
        test_samples_indices = range(total_n)
        logging.info("Number of valid samples evaluated: %d", total_n)
        valid_loss_loc_avg = r_loss_loc_valid / \
            (total_n * val_pred_length * width * height)
        valid_loss_scale_avg = r_loss_scale_valid / \
            (total_n * val_pred_length * width * height)
        valid_loss_latent_avg = r_loss_latent_valid / \
            (total_n * val_pred_length * width * height)
        logging.info("Validation obs loss for %ds pred VanillaDMM: %f",
                     val_pred_length, valid_loss_loc_avg)
        logging.info("Validation latent loss: %f", valid_loss_latent_avg)

        with open('VanillaDMMResult.log', 'a+') as fout:
            validation_log = 'Pred {}s VanillaDMM: {}\n'.format(
                val_pred_length, valid_loss_loc_avg)
            fout.write(validation_log)
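
# A minimal sketch of how main() might be invoked from the command line. The
# argument names simply mirror the attributes read above (runname, trialnumber,
# datapath, modelsavepath, startepoch, endepoch, validonly, validpredlength);
# the defaults are illustrative assumptions, not values from the source.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train VanillaDMM')
    parser.add_argument('--runname', type=str, default='dmm')
    parser.add_argument('--trialnumber', type=int, default=0)
    parser.add_argument('--datapath', type=Path, default=Path('./data'))
    parser.add_argument('--modelsavepath', type=Path, default=Path('./checkpoints'))
    parser.add_argument('--startepoch', type=int, default=0)
    parser.add_argument('--endepoch', type=int, default=1000)
    parser.add_argument('--validonly', action='store_true')
    parser.add_argument('--validpredlength', type=int, default=20)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    main(args)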
Example #2
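
# Assumed imports for this snippet (not shown in the original excerpt).
# The project-local modules (utils, datahandler) and the DMM class are assumed
# to be importable from the surrounding package.
import time
from os.path import exists

import numpy as np
import torch
import torch.nn as nn
import torchtext
from torch.autograd import Variable

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam
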
class Trainer(object):
    """
    trainer class used to instantiate, train and validate a network
    """
    def __init__(self, args):

        # argument gathering
        self.rand_seed = args.rand_seed
        self.dev_num = args.dev_num
        self.cuda = args.cuda
        self.n_epoch = args.n_epoch
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.wd = args.wd
        self.cn = args.cn
        self.lr_decay = args.lr_decay
        self.kl_ae = args.kl_ae
        self.maf = args.maf
        self.dropout = args.dropout
        self.ckpt_f = args.ckpt_f
        self.load_opt = args.load_opt
        self.load_model = args.load_model
        self.save_opt = args.save_opt
        self.save_model = args.save_model
        self.maxlen = args.maxlen
        # setup logging
        self.log = utils.get_logger(args.log)
        self.log(args)

    def _validate(self, val_iter):
        """
        freezes training and evaluates the network on the validation set
        """
        # freeze training
        self.dmm.rnn.eval()
        val_nll = 0
        for ii, batch in enumerate(iter(val_iter)):

            batch_data = Variable(batch.text[0].to(self.device))
            seqlens = Variable(batch.text[1].to(self.device))

            # transpose to [B, seqlen]; the one-hot encoding below adds the vocab dimension
            batch_data = torch.t(batch_data)
            # compute one hot character embedding
            batch_onehot = nn.functional.one_hot(batch_data,
                                                 self.vocab_size).float()
            # flip sequence for rnn
            batch_reversed = utils.reverse_seq(batch_onehot, seqlens)
            batch_reversed = nn.utils.rnn.pack_padded_sequence(
                batch_reversed, seqlens, batch_first=True)
            # compute temporal mask
            batch_mask = utils.generate_batch_mask(batch_onehot,
                                                   seqlens).cuda()
            # perform evaluation
            val_nll += self.svi.evaluate_loss(batch_onehot, batch_reversed,
                                              batch_mask, seqlens)

        # resume training
        self.dmm.rnn.train()

        loss = val_nll / self.N_val_data

        return loss

    def _train_batch(self, train_iter, epoch):
        """
        processes every mini-batch of a single training epoch
        """
        batch_loss = 0
        epoch_loss = 0
        for ii, batch in enumerate(iter(train_iter)):

            batch_data = Variable(batch.text[0].to(self.device))
            seqlens = Variable(batch.text[1].to(self.device))

            # transpose to [B, seqlen]; the one-hot encoding below adds the vocab dimension
            batch_data = torch.t(batch_data)
            # compute one hot character embedding
            batch_onehot = nn.functional.one_hot(batch_data,
                                                 self.vocab_size).float()
            # flip sequence for rnn
            batch_reversed = utils.reverse_seq(batch_onehot, seqlens)
            batch_reversed = nn.utils.rnn.pack_padded_sequence(
                batch_reversed, seqlens, batch_first=True)
            # compute temporal mask
            batch_mask = utils.generate_batch_mask(batch_onehot,
                                                   seqlens).cuda()

            # compute kl-div annealing factor
            if self.kl_ae > 0 and epoch < self.kl_ae:
                min_af = self.maf
                kl_anneal = min_af + (
                    1 - min_af) * (float(ii + epoch * self.N_batches + 1) /
                                   float(self.kl_ae * self.N_batches))
            else:
                # default kl-div annealing factor is unity
                kl_anneal = 1.0

            # take gradient step
            batch_loss = self.svi.step(batch_onehot, batch_reversed,
                                       batch_mask, seqlens, kl_anneal)
            batch_loss = batch_loss / (torch.sum(seqlens).float())
            print("loss at iteration {0} is {1}".format(ii, batch_loss))
        epoch_loss = epoch_loss + batch_loss

        return epoch_loss

    def train(self):
        """
        trains a network with a given training set
        """
        self.device = torch.device("cuda")
        np.random.seed(self.rand_seed)
        torch.manual_seed(self.rand_seed)

        train, val, test, vocab = datahandler.load_data(
            "./data/ptb", self.maxlen)

        self.vocab_size = len(vocab)

        # make iterable dataset object
        train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(
            (train, val, test),
            batch_sizes=[self.batch_size, 1, 1],
            device=self.device,
            repeat=False,
            sort_key=lambda x: len(x.text),
            sort_within_batch=True,
        )
        self.N_train_data = len(train)
        self.N_val_data = len(val)
        self.N_batches = int(self.N_train_data / self.batch_size +
                             int(self.N_train_data % self.batch_size > 0))

        self.log("N_train_data: %d  N_mini_batches: %d" %
                 (self.N_train_data, self.N_batches))

        # instantiate the dmm
        self.dmm = DMM(input_dim=self.vocab_size, dropout=self.dropout)

        # setup optimizer
        opt_params = {
            "lr": self.lr,
            "betas": (self.beta1, self.beta2),
            "clip_norm": self.cn,
            "lrd": self.lr_decay,
            "weight_decay": self.wd,
        }
        self.adam = ClippedAdam(opt_params)
        # set up inference algorithm
        self.elbo = Trace_ELBO()
        self.svi = SVI(self.dmm.model,
                       self.dmm.guide,
                       self.adam,
                       loss=self.elbo)

        # validate every val_f epochs
        val_f = 10

        self.log("training dmm")
        times = [time.time()]
        for epoch in range(self.n_epoch):

            if self.ckpt_f > 0 and epoch > 0 and epoch % self.ckpt_f == 0:
                self.save_ckpt()

            # train and report metrics
            train_nll = self._train_batch(
                train_iter,
                epoch,
            )

            times.append(time.time())
            t_elps = times[-1] - times[-2]
            self.log("epoch %04d -> train nll: %.4f \t t_elps=%.3f sec" %
                     (epoch, train_nll, t_elps))

            if epoch % val_f == 0:
                val_nll = self._validate(val_iter)
                self.log("epoch %04d -> val nll: %.4f" % (epoch, val_nll))
        pass

    def save_ckpt(self):
        """
        saves the state of the network and optimizer for later
        """
        self.log("saving model to %s" % self.save_model)
        torch.save(self.dmm.state_dict(), self.save_model)
        self.log("saving optimizer states to %s" % self.save_opt)
        self.adam.save(self.save_opt)

        pass

    def load_ckpt(self):
        """
        loads a saved checkpoint
        """
        assert exists(self.load_opt) and exists(
            self.load_model), "--load-model and/or --load-opt misspecified"
        self.log("loading model from %s..." % self.load_model)
        self.dmm.load_state_dict(torch.load(self.load_model))
        self.log("loading optimizer states from %s..." % self.load_opt)
        self.adam.load(self.load_opt)

        pass
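
# A minimal sketch of how Trainer might be wired up. Every flag below simply
# mirrors an attribute read in __init__; the defaults are illustrative
# assumptions rather than values taken from the source.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a DMM with Pyro')
    parser.add_argument('--rand-seed', type=int, default=0)
    parser.add_argument('--dev-num', type=int, default=0)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--n-epoch', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--wd', type=float, default=0.0)       # weight decay
    parser.add_argument('--cn', type=float, default=10.0)      # gradient clip norm
    parser.add_argument('--lr-decay', type=float, default=1.0)
    parser.add_argument('--kl-ae', type=int, default=1000)     # KL annealing epochs
    parser.add_argument('--maf', type=float, default=0.2)      # minimum annealing factor
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--ckpt-f', type=int, default=10)      # checkpoint frequency (epochs)
    parser.add_argument('--load-opt', type=str, default='')
    parser.add_argument('--load-model', type=str, default='')
    parser.add_argument('--save-opt', type=str, default='./checkpoints/dmm.opt')
    parser.add_argument('--save-model', type=str, default='./checkpoints/dmm.pt')
    parser.add_argument('--maxlen', type=int, default=288)
    parser.add_argument('--log', type=str, default='./train.log')
    args = parser.parse_args()

    Trainer(args).train()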