Example #1
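This example omits its imports; a minimal reconstruction follows (the project-local module paths are assumptions about the original repository layout, so they are shown commented out):

import os

import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

# Project-local helpers used below (assumed locations):
# from models.wide_resnet import WideResNet, what_conv_block
# from datasets import ImageList
# from utils import Harakiri, Meter, JsonLogger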
def load_network(loc, masked=False):
    net_checkpoint = torch.load(loc)
    start_epoch = net_checkpoint['epoch']
    SavedConv, SavedBlock = what_conv_block(net_checkpoint['conv'],
                                            net_checkpoint['blocktype'],
                                            net_checkpoint['module'])

    net = WideResNet(args.wrn_depth,
                     args.wrn_width,
                     SavedConv,
                     SavedBlock,
                     num_classes=num_classes,
                     dropRate=0,
                     masked=masked).cuda()

    if masked:
        # The masked network's state-dict keys may not match the checkpoint's
        # (e.g. when mask buffers are added), so remap the saved tensors
        # positionally. This assumes both dicts enumerate parameters in the
        # same order.
        new_sd = net.state_dict()
        old_sd = net_checkpoint['net']
        new_names = list(new_sd)
        old_names = list(old_sd)
        for new_name, old_name in zip(new_names, old_names):
            new_sd[new_name] = old_sd[old_name]
        net.load_state_dict(new_sd)
    else:
        net.load_state_dict(net_checkpoint['net'])
    return net, start_epoch
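A usage sketch (the checkpoint path here is hypothetical):

net, start_epoch = load_network('checkpoints/wrn.t7', masked=True)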
def main(args):
    harakiri = Harakiri()
    harakiri.set_max_plateau(20)
    train_loss_meter = Meter()
    val_loss_meter = Meter()
    val_accuracy_meter = Meter()
    log = JsonLogger(args.log_path, rand_folder=True)
    log.update(args.__dict__)
    state = args.__dict__
    state['exp_dir'] = os.path.dirname(log.path)
    state['start_lr'] = state['lr']
    print(state)

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    train_dataset = ImageList(args.root_folder,
                              args.train_listfile,
                              transform=transforms.Compose([
                                  transforms.Resize(256),
                                  transforms.RandomCrop(224),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  transforms.Normalize(imagenet_mean,
                                                       imagenet_std)
                              ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=False,
                                               num_workers=args.num_workers)
    val_dataset = ImageList(args.root_folder,
                            args.val_listfile,
                            transform=transforms.Compose([
                                transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize(imagenet_mean,
                                                     imagenet_std)
                            ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=False,
                                             num_workers=args.num_workers)
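    # Note (assumption about intent): the non_blocking GPU copies used in
    # train()/val() below only overlap with compute when the DataLoader is
    # created with pin_memory=True; with pin_memory=False they fall back to
    # ordinary synchronous copies.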

    if args.attention_depth == 0:
        from models.wide_resnet import WideResNet
        model = WideResNet().finetune(args.nlabels).cuda()
    else:
        from models.wide_resnet_attention import WideResNetAttention
        model = WideResNetAttention(args.nlabels, args.attention_depth,
                                    args.attention_width, args.has_gates,
                                    args.reg_weight).finetune(args.nlabels)

    # if args.load != "":
    #     model.load_state_dict(torch.load(args.load), strict=False)
    #     model = model.cuda()

    optimizer = optim.SGD(
        [{'params': model.get_base_params(), 'lr': args.lr * 0.1},
         {'params': model.get_classifier_params()}],
        lr=args.lr,
        weight_decay=1e-4,
        momentum=0.9,
        nesterov=True)
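    # Two parameter groups: the pretrained base layers train at args.lr * 0.1,
    # while the newly initialised classifier uses the full args.lr; weight
    # decay and Nesterov momentum apply to both groups.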

    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, range(args.ngpu)).cuda()
    else:
        model = model.cuda()
    # Define the criterion outside the else-branch so it also exists on the
    # multi-GPU path (the training code below uses F.nll_loss, its
    # functional equivalent).
    criterion = torch.nn.NLLLoss().cuda()

    def train():
        """Run one training epoch, recording the mean loss in `state`."""
        model.train()
        for data, label in train_loader:
            data = data.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)
            optimizer.zero_grad()
            if args.attention_depth > 0:
                # The attention model returns its regularisation term
                # alongside the logits.
                output, loss = model(data)
                loss = loss.mean() if args.reg_weight > 0 else 0
            else:
                loss = 0
                output = model(data)
            loss += F.nll_loss(output, label)
            loss.backward()
            optimizer.step()
            train_loss_meter.update(loss.item(), data.size(0))
        state['train_loss'] = train_loss_meter.mean()

    def val():
        """Evaluate on the validation set, recording mean loss and accuracy."""
        model.eval()
        with torch.no_grad():
            for data, label in val_loader:
                data = data.cuda(non_blocking=True)
                label = label.cuda(non_blocking=True)
                if args.attention_depth > 0:
                    # Discard the regularisation term at evaluation time.
                    output, _ = model(data)
                else:
                    output = model(data)
                loss = F.nll_loss(output, label)
                val_loss_meter.update(loss.item(), data.size(0))
                preds = output.max(1)[1]
                val_accuracy_meter.update((preds == label).float().sum().item(),
                                          data.size(0))
        state['val_loss'] = val_loss_meter.mean()
        state['val_accuracy'] = val_accuracy_meter.mean()

    best_accuracy = 0
    counter = 0
    for epoch in range(args.epochs):
        train()
        val()
        harakiri.update(epoch, state['val_accuracy'])
        if state['val_accuracy'] > best_accuracy:
            counter = 0
            best_accuracy = state['val_accuracy']
            if args.save:
                torch.save(model.state_dict(),
                           os.path.join(state["exp_dir"], "model.pytorch"))
        else:
            counter += 1
        state['epoch'] = epoch + 1
        log.update(state)
        print(state)
        if (epoch + 1) in args.schedule:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
            state['lr'] *= 0.1
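The manual decay above can also be expressed with PyTorch's built-in scheduler; a minimal sketch of the equivalent loop, assuming args.schedule is a list of epoch numbers:

from torch.optim.lr_scheduler import MultiStepLR

scheduler = MultiStepLR(optimizer, milestones=args.schedule, gamma=0.1)
for epoch in range(args.epochs):
    train()
    val()
    scheduler.step()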
Example #2

            test_loss += loss.item()  # accumulate the batch loss into test_loss
    avg_test_loss = test_loss / len(test_loader)  # mean loss over batches
    avg_test_acc = test_acc / len(test_loader.dataset)  # accuracy over all samples

    # print log
    print('Epoch [{}/{}], train_loss: {loss:.8f}, train_acc: {train_acc:.4f}'.
          format(epoch + 1,
                 EPOCHS,
                 loss=avg_train_loss,
                 train_acc=avg_train_acc))
    print('Epoch [{}/{}], test_loss: {loss:.8f}, test_acc: {test_acc:.4f}'.
          format(epoch + 1, EPOCHS, loss=avg_test_loss,
                 test_acc=avg_test_acc))

    # append to the lists used to plot graphs after training
    train_loss_list.append(avg_train_loss)
    train_acc_list.append(avg_train_acc)
    test_loss_list.append(avg_test_loss)
    test_acc_list.append(avg_test_acc)
    wandb.log({"epoch": epoch + 1, "train accuracy": avg_train_acc})
    wandb.log({"epoch": epoch + 1, "train accuracy": avg_train_acc})
    wandb.log({"epoch": epoch + 1, "test accuracy": avg_test_acc})
    wandb.log({"epoch": epoch + 1, "train loss": avg_train_loss})
    wandb.log({"epoch": epoch + 1, "test loss": avg_test_loss})

end_time = time.time()
print('elapsed time: {:.4f}'.format(end_time - start_time))

torch.save(model.state_dict(), "result/model_weight/" + run_name + '.pth')
wandb.save("result/model_weight/" + run_name + '.pth')
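To reuse the checkpoint later, a minimal loading sketch (assuming the same model class is instantiated first):

model.load_state_dict(torch.load("result/model_weight/" + run_name + ".pth"))
model.eval()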