예제 #1
0
def train_discriminator(dis, gen, criterion, optimizer, epochs,
                        dis_adversarial_train_loss, dis_adversarial_train_acc,
                        args):
    """
    Train the discriminator on positive samples versus generated samples.

    Negative samples are produced once, before the first epoch, by calling
    ``generate_samples`` with ``gen`` and writing to ``NEGATIVE_FILE``. Every
    epoch then iterates the same ``DisDataIter`` built from ``POSITIVE_FILE``
    and ``NEGATIVE_FILE``, and the iterator is reset at the end of each epoch.

    Args:
        dis: Discriminator model, invoked as ``dis(data)``.
        gen: Generator handed to ``generate_samples``.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` wrap each
            backward pass.
        epochs: Number of passes over the data iterator.
        dis_adversarial_train_loss: List that receives one average loss
            value per epoch (mutated in place).
        dis_adversarial_train_acc: List that receives one accuracy value
            per epoch (mutated in place).
        args: Object exposing ``batch_size``, ``n_samples`` and ``cuda``.

    Returns:
        None. Results are reported via ``print`` and the two output lists.
    """
    generate_samples(gen, args.batch_size, args.n_samples, NEGATIVE_FILE)
    data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, args.batch_size)
    for epoch in range(epochs):
        # Kept as a plain Python int for the whole epoch so the accuracy
        # computation below never depends on a tensor-only method, even if
        # the iterator yields no batch.
        correct = 0
        total_loss = 0.
        for data, target in data_iter:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            # Flatten the labels to 1-D before comparing / computing the loss.
            target = target.contiguous().view(-1)
            output = dis(data)
            # Index of the largest score along dim 1 is the predicted label;
            # detach() replaces the legacy ``.data`` accessor.
            pred = output.detach().max(1)[1]
            correct += pred.eq(target).sum().item()
            loss = criterion(output, target)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        data_iter.reset()
        # Loss is averaged per batch (len(data_iter)), accuracy per sample
        # (data_iter.data_num).
        avg_loss = total_loss / len(data_iter)
        acc = correct / data_iter.data_num
        print("Epoch {}, train loss: {:.5f}, train acc: {:.3f}".format(
            epoch, avg_loss, acc))
        dis_adversarial_train_loss.append(avg_loss)
        dis_adversarial_train_acc.append(acc)
예제 #2
0
                total_words = 0.                        
                n = 0 
                for (data, target) in dis_data_iter:
                    n+=1
                    data = Variable(data)
                    target = Variable(target)
                    if opt.cuda:
                        data, target = data.cuda(), target.cuda()
                    target = target.contiguous().view(-1) 
                    pred = discriminator.forward(data) 
                    loss = dis_criterion(pred, target) # negative log likelihood loss                            
                    total_loss += loss.item()
                    total_words += data.size(0) * data.size(1)       
                    
                    dis_optimizer.zero_grad() 
                    loss.backward()     
                    dis_optimizer.step()      
                
                dis_data_iter.reset() 
                f_loss = math.exp(total_loss/ total_words) 

                d_save_path = os.path.join(m_save_path, c_cat)
                if not os.path.exists(d_save_path):
                    os.mkdir(d_save_path)
                d_save_path = os.path.join(d_save_path, 'discriminator'+str(total_batch)+'.pkl')                        
                # torch.save(discriminator.state_dict(), d_save_path)
        print('total_d_loss ', total_loss)
        print('f_d_loss ',f_loss)