Example #1
File: train.py Project: yyht/rrws
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, control_variate):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_filename(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.pcfg_path, self.model_folder)
            util.save_control_variate(control_variate, self.model_folder)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_to_true_history.append(
                util.get_q_error(self.true_generative_model,
                                 inference_network))
            self.q_error_to_model_history.append(
                util.get_q_error(generative_model, inference_network))
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = {:.3f}, '
                'q_error_to_model = {:.3f}'.format(
                    iteration, self.p_error_history[-1],
                    self.q_error_to_true_history[-1],
                    self.q_error_to_model_history[-1]))
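These stats objects are callbacks invoked once per training iteration. A minimal sketch of the calling convention (illustrative driver with placeholder values, not the project's actual train loop):

def run(num_iterations, callback):
    # Illustrative only: a callback with this signature is invoked once per
    # iteration with the current loss/ELBO and the models being trained.
    for iteration in range(num_iterations):
        loss, elbo = 1.0, -2.0  # placeholders for a real training step
        callback(iteration, loss, elbo,
                 generative_model=None, inference_network=None,
                 control_variate=None)

run(3, lambda iteration, loss, elbo, **models: print(iteration, loss, elbo))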
Example #2
File: train.py Project: yyht/rrws
    def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
                 generative_model, inference_network, optimizer_theta,
                 optimizer_phi):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
                '{:.3f}'.format(iteration, wake_theta_loss, wake_phi_loss,
                                elbo))
            self.wake_theta_loss_history.append(wake_theta_loss)
            self.wake_phi_loss_history.append(wake_phi_loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_filename(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.pcfg_path, self.model_folder)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_to_true_history.append(
                util.get_q_error(self.true_generative_model,
                                 inference_network))
            self.q_error_to_model_history.append(
                util.get_q_error(generative_model, inference_network))
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = {:.3f}, '
                'q_error_to_model = {:.3f}'.format(
                    iteration, self.p_error_history[-1],
                    self.q_error_to_true_history[-1],
                    self.q_error_to_model_history[-1]))
Example #3
    def __call__(self, iteration, theta_loss, phi_loss, generative_model,
                 inference_network, memory, optimizer):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}'.format(
                    iteration, theta_loss, phi_loss))
            self.theta_loss_history.append(theta_loss)
            self.phi_loss_history.append(phi_loss)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_path(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.model_folder, iteration, memory)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_history.append(
                util.get_q_error(self.true_generative_model, inference_network,
                                 self.test_obss))
            # TODO
            # self.memory_error_history.append(util.get_memory_error(
            #     self.true_generative_model, memory, generative_model,
            #     self.test_obss))
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = '
                '{:.3f}'.format(iteration, self.p_error_history[-1],
                                self.q_error_history[-1]))
Example #4
def run(args):
    util.print_with_time(str(args))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_filename(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        load_or_init_models(args.load_model_folder, args.pcfg_path)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(generative_model.grammar)

    # train
    if args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(
            args.pcfg_path, model_folder, true_generative_model,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              true_generative_model, args.batch_size,
                              args.num_iterations, args.num_particles,
                              train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            args.pcfg_path, model_folder, true_generative_model,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               true_generative_model, args.batch_size,
                               args.num_iterations, args.num_particles,
                               train_callback)
    elif args.train_mode == 'reinforce' or args.train_mode == 'vimco':
        train_callback = train.TrainIwaeCallback(
            args.pcfg_path, model_folder, true_generative_model,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         true_generative_model, args.batch_size,
                         args.num_iterations, args.num_particles,
                         train_callback)
    elif args.train_mode == 'relax':
        train_callback = train.TrainRelaxCallback(
            args.pcfg_path, model_folder, true_generative_model,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_relax(generative_model, inference_network, control_variate,
                          true_generative_model, args.batch_size,
                          args.num_iterations, args.num_particles,
                          train_callback)

    # save models and stats
    util.save_models(generative_model, inference_network, args.pcfg_path,
                     model_folder)
    stats_filename = util.get_stats_filename(model_folder)
    util.save_object(train_callback, stats_filename)
Example #5
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, optimizer):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_path(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.model_folder, iteration)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_history.append(
                util.get_q_error(self.true_generative_model, inference_network,
                                 self.test_obss))
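            # Gradient-noise probe: repeat the loss/backward pass 10 times on
            # the test observations and track the std of each parameter's
            # gradient with a running mean/std accumulator.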
            stats = util.OnlineMeanStd()
            for _ in range(10):
                inference_network.zero_grad()
                if self.train_mode == 'vimco':
                    loss, elbo = losses.get_vimco_loss(generative_model,
                                                       inference_network,
                                                       self.test_obss,
                                                       self.num_particles)
                elif self.train_mode == 'reinforce':
                    loss, elbo = losses.get_reinforce_loss(
                        generative_model, inference_network, self.test_obss,
                        self.num_particles)
                loss.backward()
                stats.update([p.grad for p in inference_network.parameters()])
            self.grad_std_history.append(stats.avg_of_means_stds()[1])
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = '
                '{:.3f}'.format(iteration, self.p_error_history[-1],
                                self.q_error_history[-1]))
Example #6
def train_reconstruction(train_loader, test_loader, encoder, decoder, args):
    exp = Experiment("Reconstruction Training")
    try:
        lr = args.lr
        encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
        decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)

        encoder.train()
        decoder.train()
        steps = 0
        for epoch in range(1, args.epochs+1):
            print("=======Epoch========")
            print(epoch)
            for batch in train_loader:
                feature = Variable(batch)
                if args.use_cuda:
                    encoder.cuda()
                    decoder.cuda()
                    feature = feature.cuda()

                encoder_opt.zero_grad()
                decoder_opt.zero_grad()

                h = encoder(feature)
                prob = decoder(h)
                reconstruction_loss = compute_cross_entropy(prob, feature)
                reconstruction_loss.backward()
                encoder_opt.step()
                decoder_opt.step()

                steps += 1
                print("Epoch: {}".format(epoch))
                print("Steps: {}".format(steps))
                print("Loss: {}".format(reconstruction_loss.data[0] / args.sentence_len))
                exp.metric("Loss", reconstruction_loss.data[0] / args.sentence_len)
                # check reconstructed sentence
                if steps % args.log_interval == 0:
                    print("Test!!")
                    input_data = feature[0]
                    single_data = prob[0]
                    _, predict_index = torch.max(single_data, 1)
                    input_sentence = util.transform_id2word(input_data.data, train_loader.dataset.index2word, lang="en")
                    predict_sentence = util.transform_id2word(predict_index.data, train_loader.dataset.index2word, lang="en")
                    print("Input Sentence:")
                    print(input_sentence)
                    print("Output Sentence:")
                    print(predict_sentence)

            if steps % args.test_interval == 0:
                eval_reconstruction(encoder, decoder, test_loader, args)

            if epoch % args.lr_decay_interval == 0:
                # decrease learning rate
                lr = lr / 5
                encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
                decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
                encoder.train()
                decoder.train()

            if epoch % args.save_interval == 0:
                util.save_models(encoder, args.save_dir, "encoder", steps)
                util.save_models(decoder, args.save_dir, "decoder", steps)

        # finalization
        # save vocabulary
        with open("word2index", "wb") as w2i, open("index2word", "wb") as i2w:
            pickle.dump(train_loader.dataset.word2index, w2i)
            pickle.dump(train_loader.dataset.index2word, i2w)

        # save models
        util.save_models(encoder, args.save_dir, "encoder", "final")
        util.save_models(decoder, args.save_dir, "decoder", "final")

        print("Finish!!!")
    finally:
        exp.end()
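One design note on the lr-decay step above: re-creating torch.optim.Adam discards its running moment estimates. A minimal self-contained alternative (plain PyTorch, illustrative model) scales the existing parameter groups in place:

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Decay the learning rate in place instead of re-creating the optimizer,
# which would reset Adam's moment estimates.
for param_group in opt.param_groups:
    param_group['lr'] /= 5
print(opt.param_groups[0]['lr'])  # 0.0002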
Example #7
def train_classification(data_loader, dev_iter, encoder, decoder, mlp, args):
    lr = args.lr
    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
    mlp_opt = torch.optim.Adam(mlp.parameters(), lr=lr)

    encoder.train()
    decoder.train()
    mlp.train()
    steps = 0
    for epoch in range(1, args.epochs+1):
        alpha = util.sigmoid_annealing_schedule(epoch, args.epochs)
        print("=======Epoch========")
        print(epoch)
        for batch in data_loader:
            feature, target = Variable(batch["sentence"]), Variable(batch["label"])
            if args.use_cuda:
                encoder.cuda()
                decoder.cuda()
                mlp.cuda()
                feature, target = feature.cuda(), target.cuda()

            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            mlp_opt.zero_grad()

            h = encoder(feature)
            prob = decoder(h)
            log_prob = mlp(h.squeeze())
            reconstruction_loss = compute_cross_entropy(prob, feature)
            supervised_loss = F.nll_loss(log_prob, target.view(target.size()[0]))
            loss = alpha * reconstruction_loss + supervised_loss
            loss.backward()
            encoder_opt.step()
            decoder_opt.step()
            mlp_opt.step()

            steps += 1
            print("Epoch: {}".format(epoch))
            print("Steps: {}".format(steps))
            print("Loss: {}".format(loss.data[0]))
            # check reconstructed sentence and classification
            if steps % args.log_interval == 0:
                print("Test!!")
                input_data = feature[0]
                input_label = target[0]
                single_data = prob[0]
                _, predict_index = torch.max(single_data, 1)
                input_sentence = util.transform_id2word(input_data.data, data_loader.dataset.index2word, lang="ja")
                predict_sentence = util.transform_id2word(predict_index.data, data_loader.dataset.index2word, lang="ja")
                print("Input Sentence:")
                print(input_sentence)
                print("Output Sentence:")
                print(predict_sentence)
                eval_classification(encoder, mlp, input_data, input_label)

        if epoch % args.lr_decay_interval == 0:
            # decrease learning rate
            lr = lr / 5
            encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
            decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
            mlp_opt = torch.optim.Adam(mlp.parameters(), lr=lr)
            encoder.train()
            decoder.train()
            mlp.train()

        if epoch % args.save_interval == 0:
            util.save_models(encoder, args.save_dir, "encoder", steps)
            util.save_models(decoder, args.save_dir, "decoder", steps)
            util.save_models(mlp, args.save_dir, "mlp", steps)

    # finalization
    # save vocabulary
    with open("word2index", "wb") as w2i, open("index2word", "wb") as i2w:
        pickle.dump(data_loader.dataset.word2index, w2i)
        pickle.dump(data_loader.dataset.index2word, i2w)

    # save models
    util.save_models(encoder, args.save_dir, "encoder", "final")
    util.save_models(decoder, args.save_dir, "decoder", "final")
    util.save_models(mlp, args.save_dir, "mlp", "final")

    print("Finish!!!")
Example #8
def run(args):
    # set up args
    args.device = None
    if args.cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')
    args.num_mixtures = 20
    if args.init_near:
        args.init_mixture_logits = np.ones(args.num_mixtures)
    else:
        args.init_mixture_logits = np.array(
            list(reversed(2 * np.arange(args.num_mixtures))))
    args.softmax_multiplier = 0.5
    if args.train_mode == 'concrete':
        args.relaxed_one_hot = True
        args.temperature = 3
    else:
        args.relaxed_one_hot = False
        args.temperature = None
    temp = np.arange(args.num_mixtures) + 5
    true_p_mixture_probs = temp / np.sum(temp)
    args.true_mixture_logits = \
        np.log(true_p_mixture_probs) / args.softmax_multiplier
    util.print_with_time(str(args))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_path(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        util.init_models(args)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(args.num_mixtures)

    # init dataloader
    obss_data_loader = torch.utils.data.DataLoader(
        true_generative_model.sample_obs(args.num_obss),
        batch_size=args.batch_size,
        shuffle=True)

    # train
    if args.train_mode == 'mws':
        train_callback = train.TrainMWSCallback(model_folder,
                                                true_generative_model,
                                                args.logging_interval,
                                                args.checkpoint_interval,
                                                args.eval_interval)
        train.train_mws(generative_model, inference_network, obss_data_loader,
                        args.num_iterations, args.mws_memory_size,
                        train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            model_folder, true_generative_model,
            args.batch_size * args.num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               obss_data_loader, args.num_iterations,
                               args.num_particles, train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(model_folder,
                                                     true_generative_model,
                                                     args.num_particles,
                                                     args.logging_interval,
                                                     args.checkpoint_interval,
                                                     args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              obss_data_loader, args.num_iterations,
                              args.num_particles, train_callback)
    elif args.train_mode == 'dww':
        train_callback = train.TrainDefensiveWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles, 0.2,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_defensive_wake_wake(0.2, generative_model,
                                        inference_network, obss_data_loader,
                                        args.num_iterations,
                                        args.num_particles, train_callback)
    elif args.train_mode == 'reinforce' or args.train_mode == 'vimco':
        train_callback = train.TrainIwaeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.train_mode, args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'concrete':
        train_callback = train.TrainConcreteCallback(
            model_folder, true_generative_model, args.num_particles,
            args.num_iterations, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'relax':
        train_callback = train.TrainRelaxCallback(model_folder,
                                                  true_generative_model,
                                                  args.num_particles,
                                                  args.logging_interval,
                                                  args.checkpoint_interval,
                                                  args.eval_interval)
        train.train_relax(generative_model, inference_network, control_variate,
                          obss_data_loader, args.num_iterations,
                          args.num_particles, train_callback)

    # save models and stats
    util.save_models(generative_model, inference_network, model_folder)
    if args.train_mode == 'relax':
        util.save_control_variate(control_variate, model_folder)
    stats_filename = util.get_stats_path(model_folder)
    util.save_object(train_callback, stats_filename)
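A self-contained check of the mixture-logit construction above (numpy only, constants copied from the example): applying a softmax to softmax_multiplier * true_mixture_logits recovers true_p_mixture_probs, since scaling log p by 1/m and back by m leaves log p unchanged, and softmax(log p) = p for normalized p.

import numpy as np

num_mixtures = 20
softmax_multiplier = 0.5

# Same construction as in the example above.
temp = np.arange(num_mixtures) + 5
true_p_mixture_probs = temp / np.sum(temp)
true_mixture_logits = np.log(true_p_mixture_probs) / softmax_multiplier

# softmax(softmax_multiplier * logits) recovers the probabilities.
scaled = softmax_multiplier * true_mixture_logits
recovered = np.exp(scaled) / np.sum(np.exp(scaled))
assert np.allclose(recovered, true_p_mixture_probs)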
Example #9
def main(file_train_path, stop_path, rare_len, epochs, embed_dim, batch_size,
         shuffle, sentence_len, filter_size, latent_size, n_class, LR,
         save_interval, save_dir, use_cuda):
    data_all, labels, rare_word, word2index, index2word = deal_with_data(
        file_path=file_train_path, stop_path=stop_path,
        sentence_len=sentence_len, rare_len=rare_len,
        rare_word=[]).word_to_id()
    
    # total number of words in the vocabulary
    counts_words_len = len(word2index)

    # total number of samples
    sample_len = len(labels)

    # 70/30 train/test split by sample index
    train_data = data_set(data_all[0:int(0.7 * sample_len)],
                          labels[0:int(0.7 * sample_len)],
                          transform=ToTensor())
    test_data = data_set(data_all[int(0.7 * sample_len):sample_len],
                         labels[int(0.7 * sample_len):sample_len],
                         transform=ToTensor())

    data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=shuffle)
    # DataLoader needs an integer batch size; the original len(test_data)/100
    # is a float under Python 3.
    test_loader = DataLoader(test_data,
                             batch_size=max(1, len(test_data) // 100),
                             shuffle=shuffle)

    # build the embedding layer
    embedding = nn.Embedding(counts_words_len, embed_dim, max_norm=1.0,
                             norm_type=2.0)

    # build the TextCNN model
    cnn = textcnn_me.TextCNN(embedding=embedding, sentence_len=sentence_len,
                             filter_size=filter_size, latent_size=latent_size,
                             n_class=n_class)

    cnn_opt = torch.optim.Adam(cnn.parameters(), lr=LR)

    # loss function
    loss_function = nn.CrossEntropyLoss()

    steps = 0
    for epoch in range(1, epochs + 1):
        print("=======Epoch========")
        print(epoch)
        for batch in data_loader:
            feature, target = Variable(batch["sentence"]), Variable(batch["label"])
            if use_cuda:
                cnn.cuda()
                feature, target = feature.cuda(), target.cuda()

            cnn_opt.zero_grad()

            output = cnn(feature)
            loss = loss_function(output, target.view(target.size()[0]))
            loss.backward()
            cnn_opt.step()

            steps += 1
            print("Epoch: {}".format(epoch))
            print("Steps: {}".format(steps))
            print("Loss: {}".format(loss.data[0]))

        # Checkpointing and evaluation moved to epoch level; in the original
        # they sat inside the batch loop, saving and running the full test
        # set after every single training step.
        if epoch % save_interval == 0:
            util.save_models(cnn, save_dir, "cnn", epoch)

        for batch in test_loader:
            test_feature, test_target = Variable(batch["sentence"]), Variable(batch["label"])
            test_output = cnn(test_feature)
            pred_y = torch.max(test_output, 1)[1]
            acc = (test_target.view(test_target.size()[0]) == pred_y)
            acc = acc.numpy().sum()
            accuracy = acc / test_target.size(0)
            print('test_acc: {}'.format(accuracy))

    print("Finish!!!")
    return cnn
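A device-agnostic variant of the accuracy computation above (plain PyTorch sketch; the original acc.numpy() call fails for CUDA tensors):

import torch

def batch_accuracy(logits, target):
    # Fraction of correct predictions; works on CPU or GPU tensors.
    pred = torch.max(logits, 1)[1]
    return (pred == target.view(-1)).float().mean().item()

print(batch_accuracy(torch.tensor([[0.1, 0.9], [0.8, 0.2]]),
                     torch.tensor([1, 1])))  # 0.5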