Example #1
File: Main.py  Project: brycexu/SAA
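This main loop prepares episodic train/val/test loaders, trains the SAAM network with a combined prediction + auxiliary loss, checkpoints the best validation model, and reports test accuracy averaged over 10 passes.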
import copy
import os

import numpy as np
import torch

# dataloader, SAAMNetwork, splitInput, loss_fn and loss_fn2 are defined
# in the project's own modules.


def main(parser, logger):
    print('--> Preparing Dataset:')
    trainset = dataloader(parser, 'train')
    valset = dataloader(parser, 'val')
    testset = dataloader(parser, 'test')
    print('--> Preparing Word Embedding Model')
    print('--> Building Model:')
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = SAAMNetwork()
    #model = nn.DataParallel(model, device_ids=[0,1,2,3])
    model = model.to(device)
    print('--> Initializing Optimizer and Scheduler')
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=parser.learning_rate,
                                 weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                     gamma=0.5,
                                                     milestones=[30, 60, 90])
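    # halve the learning rate (gamma=0.5) at epochs 30, 60 and 90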
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    best_acc = 0
    best_model_path = os.path.join(parser.experiment_root, 'best_model.pth')
    # snapshot (deep copy) of the weights; state_dict() alone returns live
    # references that training would keep mutating
    best_state = copy.deepcopy(model.state_dict())
    for epoch in range(parser.epochs):
        print('\nEpoch: %d' % epoch)
        # Training
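        # Each batch is one few-shot episode: splitInput carves out
        # num_support_tr labelled support samples per class and the
        # remainder become queries.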
        model.train()
        for batch_index, (inputs, targets,
                          targets_embedding) in enumerate(trainset):
            support_samples, support_embeddings, query_samples = \
                splitInput(inputs, targets, targets_embedding, parser.num_support_tr)
            support_samples = support_samples.to(device)
            query_samples = query_samples.to(device)
            targets = targets.to(device)
            support_embeddings = support_embeddings.to(device)
            support_output, query_output, loss_output = model(
                x=support_samples,
                y=query_samples,
                x_emb=support_embeddings,
                n_class=parser.classes_per_it_tr,
                n_support=parser.num_support_tr)
            loss_pred, acc = loss_fn(support_input=support_output,
                                     query_input=query_output,
                                     targets=targets,
                                     n_support=parser.num_support_tr)
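            # total objective: prediction loss plus the extra loss term
            # returned by the model's forward pass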
            loss = loss_pred + loss_output
            train_loss.append(loss_pred.item())
            train_acc.append(acc.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        avg_loss = np.mean(train_loss[-parser.iterations:])
        avg_acc = 100. * np.mean(train_acc[-parser.iterations:])
        print('Training Loss: {} | Accuracy: {}'.format(avg_loss, avg_acc))
        # Validating
        model.eval()
        with torch.no_grad():  # no gradient bookkeeping during validation
            for batch_index, (inputs, targets,
                              targets_embedding) in enumerate(valset):
                support_samples, support_embeddings, query_samples = \
                    splitInput(inputs, targets, targets_embedding,
                               parser.num_support_val)
                support_samples = support_samples.to(device)
                query_samples = query_samples.to(device)
                targets = targets.to(device)
                support_embeddings = support_embeddings.to(device)
                support_output, query_output, _ = model(
                    x=support_samples,
                    y=query_samples,
                    x_emb=support_embeddings,
                    n_class=parser.classes_per_it_val,
                    n_support=parser.num_support_val)
                loss, acc = loss_fn2(support_input=support_output,
                                     query_input=query_output,
                                     targets=targets,
                                     n_support=parser.num_support_val)
                val_loss.append(loss.item())
                val_acc.append(acc.item())
        avg_loss = np.mean(val_loss[-parser.iterations:])
        avg_acc = 100. * np.mean(val_acc[-parser.iterations:])
        print('Validating Loss: {} | Accuracy: {}'.format(avg_loss, avg_acc))
        info = {'val_loss': avg_loss, 'val_accuracy': avg_acc}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch + 1)
        if avg_acc > best_acc:
            torch.save(model.state_dict(), best_model_path)
            best_acc = avg_acc
            best_state = copy.deepcopy(model.state_dict())
    # Testing: restore the best checkpoint and average accuracy over 10
    # passes through the test episodes.
    model.load_state_dict(best_state)
    model.eval()
    test_acc = []
    with torch.no_grad():
        for _ in range(10):
            for batch_index, (inputs, targets,
                              targets_embedding) in enumerate(testset):
                support_samples, support_embeddings, query_samples = \
                    splitInput(inputs, targets, targets_embedding,
                               parser.num_support_val)
                support_samples = support_samples.to(device)
                query_samples = query_samples.to(device)
                targets = targets.to(device)
                support_embeddings = support_embeddings.to(device)
                support_output, query_output, _ = model(
                    x=support_samples,
                    y=query_samples,
                    x_emb=support_embeddings,
                    n_class=parser.classes_per_it_val,
                    n_support=parser.num_support_val)
                _, acc = loss_fn2(support_input=support_output,
                                  query_input=query_output,
                                  targets=targets,
                                  n_support=parser.num_support_val)
                test_acc.append(acc.item())
    avg_acc = 100. * np.mean(test_acc)
    logger.scalar_summary('test_accuracy', avg_acc, 1)
    print('*****Testing Accuracy: {}'.format(avg_acc))
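For reference, a minimal sketch of how this main() might be driven. The
field names are exactly the attributes the function reads off parser; the
values are placeholders, and the logger is assumed to expose a
scalar_summary(tag, value, step) method, as the code above requires.

import argparse

args = argparse.Namespace(
    learning_rate=1e-3,        # Adam step size (placeholder value)
    epochs=100,                # number of training epochs
    iterations=100,            # episodes per epoch, used for running means
    experiment_root='./runs',  # where best_model.pth is written
    classes_per_it_tr=20,      # classes per training episode
    num_support_tr=5,          # support samples per class (train)
    classes_per_it_val=5,      # classes per validation episode
    num_support_val=5,         # support samples per class (val/test)
)
main(args, logger)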
Example #2
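This variant works on plain (input, target) episodes with a single Network and tracks two validation streams, valset and valset2, keeping the best weights under each protocol and testing them on the matching test loaders.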
import copy

import numpy as np
import torch

# dataloader, dataloader2, Network and loss_fn come from the project's
# own modules.


def main(parser, logger):
    print('--> Preparing Dataset:')
    trainset = dataloader(parser, 'train')
    valset = dataloader(parser, 'val')
    valset2 = dataloader2(parser, 'val')
    testset = dataloader(parser, 'test')
    testset2 = dataloader2(parser, 'test')
    print('--> Building Model:')
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = Network().to(device)
    print('--> Initializing Optimizer and Scheduler')
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=parser.learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer,
        gamma=parser.lr_scheduler_gamma,
        step_size=parser.lr_scheduler_step)
    val_loss1 = []
    val_acc1 = []
    best_acc1 = 0
    best_state1 = copy.deepcopy(model.state_dict())  # snapshot, not a live view
    val_loss2 = []
    val_acc2 = []
    best_acc2 = 0
    best_state2 = copy.deepcopy(model.state_dict())
    for epoch in range(parser.epochs):
        print('\nEpoch: %d' % epoch)
        # Training
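        # n_support tells loss_fn how many samples per class to treat as
        # the labelled support set; the rest of the episode are queries.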
        model.train()
        for batch_index, (inputs, targets) in enumerate(trainset):
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss, acc = loss_fn(input=output,
                                target=targets,
                                n_support=parser.num_support_tr)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        # Validating
        model.eval()
        with torch.no_grad():  # no gradient bookkeeping during validation
            for batch_index, (inputs, targets) in enumerate(valset):
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss, acc = loss_fn(input=output,
                                    target=targets,
                                    n_support=parser.num_support_val)
                val_loss1.append(loss.item())
                val_acc1.append(acc.item())
        avg_loss = np.mean(val_loss1[-parser.iterations:])
        avg_acc = 100. * np.mean(val_acc1[-parser.iterations:])
        print('Validating1 Loss: {} | Accuracy1: {}'.format(avg_loss, avg_acc))
        info = {'val_loss1': avg_loss, 'val_accuracy1': avg_acc}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch + 1)
        if avg_acc > best_acc1:
            best_acc1 = avg_acc
            best_state1 = copy.deepcopy(model.state_dict())
        with torch.no_grad():
            for batch_index, (inputs, targets) in enumerate(valset2):
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss, acc = loss_fn(input=output,
                                    target=targets,
                                    n_support=parser.num_support_val2)
                val_loss2.append(loss.item())
                val_acc2.append(acc.item())
        avg_loss = np.mean(val_loss2[-parser.iterations:])
        avg_acc = 100. * np.mean(val_acc2[-parser.iterations:])
        print('Validating2 Loss: {} | Accuracy2: {}'.format(avg_loss, avg_acc))
        info = {'val_loss2': avg_loss, 'val_accuracy2': avg_acc}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch + 1)
        if avg_acc > best_acc2:
            best_acc2 = avg_acc
            best_state2 = copy.deepcopy(model.state_dict())
    # Testing: evaluate each best checkpoint on its matching test loader,
    # averaging accuracy over 10 passes.
    model.load_state_dict(best_state1)
    model.eval()
    test_acc = []
    with torch.no_grad():
        for _ in range(10):
            for batch_index, (inputs, targets) in enumerate(testset):
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                _, acc = loss_fn(input=output,
                                 target=targets,
                                 n_support=parser.num_support_val)
                test_acc.append(acc.item())
    avg_acc = 100. * np.mean(test_acc)
    logger.scalar_summary('test_accuracy1', avg_acc, 1)
    print('*****Testing1 Accuracy: {}'.format(avg_acc))
    model.load_state_dict(best_state2)
    model.eval()
    test_acc2 = []
    with torch.no_grad():
        for _ in range(10):
            for batch_index, (inputs, targets) in enumerate(testset2):
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                _, acc = loss_fn(input=output,
                                 target=targets,
                                 n_support=parser.num_support_val2)
                test_acc2.append(acc.item())
    avg_acc = 100. * np.mean(test_acc2)
    logger.scalar_summary('test_accuracy2', avg_acc, 1)
    print('*****Testing2 Accuracy: {}'.format(avg_acc))
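Keeping two separate best-state snapshots lets one training run be model-selected under two different validation protocols (here different support sizes, num_support_val vs. num_support_val2) and then tested on the matching test loader for each.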
Example #3
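This version mirrors Example #2's single-loss setup but replicates the model over four GPUs with nn.DataParallel and saves the best checkpoint to disk.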
import copy
import os

import numpy as np
import torch
import torch.nn as nn

# dataloader, Network and loss_fn come from the project's own modules.


def main(parser, logger):
    print('--> Preparing Dataset:')
    trainset = dataloader(parser, 'train')
    valset = dataloader(parser, 'val')
    testset = dataloader(parser, 'test')
    print('--> Building Model:')
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = Network()
    # replicate the model across four GPUs; assumes device_ids 0-3 exist
    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
    model = model.to(device)
    print('--> Initializing Optimizer and Scheduler')
    optimizer = torch.optim.Adam(params=model.parameters(), lr=parser.learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                gamma=parser.lr_scheduler_gamma,
                                                step_size=parser.lr_scheduler_step)
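    # StepLR multiplies the learning rate by lr_scheduler_gamma every
    # lr_scheduler_step epochs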
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    best_acc = 0
    best_model_path = os.path.join(parser.experiment_root, 'best_model.pth')
    best_state = copy.deepcopy(model.state_dict())  # snapshot, not a live view
    for epoch in range(parser.epochs):
        print('\nEpoch: %d' % epoch)
        # Training
        model.train()
        for batch_index, (inputs, targets) in enumerate(trainset):
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss, acc = loss_fn(input=output, target=targets, n_support=parser.num_support_tr)
            train_loss.append(loss.item())
            train_acc.append(acc.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        avg_loss = np.mean(train_loss[-parser.iterations:])
        avg_acc = 100. * np.mean(train_acc[-parser.iterations:])
        print('Training Loss: {} | Accuracy: {}'.format(avg_loss, avg_acc))
        # Validating
        model.eval()
        with torch.no_grad():  # no gradient bookkeeping during validation
            for batch_index, (inputs, targets) in enumerate(valset):
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss, acc = loss_fn(input=output, target=targets,
                                    n_support=parser.num_support_val)
                val_loss.append(loss.item())
                val_acc.append(acc.item())
        avg_loss = np.mean(val_loss[-parser.iterations:])
        avg_acc = 100. * np.mean(val_acc[-parser.iterations:])
        print('Validating Loss: {} | Accuracy: {}'.format(avg_loss, avg_acc))
        info = {'val_loss': avg_loss, 'val_accuracy': avg_acc}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, epoch + 1)
        if avg_acc > best_acc:
            torch.save(model.state_dict(), best_model_path)
            best_acc = avg_acc
            best_state = copy.deepcopy(model.state_dict())
    # Testing: reload the best weights and average accuracy over 10 passes.
    model.load_state_dict(best_state)
    model.eval()
    test_acc = []
    with torch.no_grad():
        for _ in range(10):
            for batch_index, (inputs, targets) in enumerate(testset):
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                _, acc = loss_fn(input=output, target=targets,
                                 n_support=parser.num_support_val)
                test_acc.append(acc.item())
    avg_acc = 100. * np.mean(test_acc)
    logger.scalar_summary('test_accuracy', avg_acc, 1)
    print('*****Testing Accuracy: {}'.format(avg_acc))