Ejemplo n.º 1
0
                    rho=args.rho,
                    lr=args.learning_rate,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    # Learning-rate schedule built from the optimizer, the base learning rate
    # and the total number of epochs. It is used below as a callable
    # (`scheduler(epoch)`) with an `lr()` accessor, so this is presumably a
    # project-local StepLR rather than torch.optim.lr_scheduler.StepLR --
    # confirm against the imports.
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)

    # One pass over dataset.train per epoch; each batch performs two
    # forward/backward passes around optimizer.first_step / second_step.
    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            # Move every tensor of the batch to the target device.
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            # `predictions` and `loss` from this pass are the values that get
            # logged further down; the second pass does not rebind them.
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            # NOTE(review): the model is still in train mode for this second
            # forward pass, so any layer that tracks running statistics would
            # be updated twice per batch -- confirm whether that is intended.
            smooth_crossentropy(model(inputs), targets).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                # Per-sample correctness mask based on the first-pass outputs.
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                # The scheduler is invoked once per batch with the epoch index.
                scheduler(epoch)

        # Switch to evaluation mode before the test pass that follows.
        model.eval()
        log.eval(len_dataset=len(dataset.test))
Ejemplo n.º 2
0
    # Restore the trained weights. map_location makes a checkpoint that was
    # saved on a GPU loadable when `device` is the CPU.
    PATH = './trained_models/sam_net_250.pth'
    model.load_state_dict(torch.load(PATH, map_location=device))
    # Inference only: make sure train-time layer behaviour is switched off.
    # This is a no-op if earlier code already put the model in eval mode.
    model.eval()

    # Per-batch results are collected in lists and concatenated once after the
    # loop; growing an array with np.append re-copies it on every batch.
    predict_batches = []
    correct_batches = []
    target_batches = []
    with torch.no_grad():
        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # Replace every label of the batch with a random class drawn by
            # randint(0, 9). The loop runs over the actual batch length
            # instead of a hard-coded 128, so larger batches are fully
            # randomised as well.
            rands = torch.clone(targets)
            for i in range(len(rands)):
                rands[i] = randint(0, 9)

            predictions = model(inputs)
            # `loss` is not used in the visible code; kept in case code past
            # this snippet reads it.
            loss = smooth_crossentropy(predictions, rands)
            # Correctness is measured against the random labels, not the
            # true targets.
            correct = torch.argmax(predictions, 1) == rands

            target_batches.append(targets.cpu().numpy())
            predict_batches.append(predictions.cpu().numpy())
            correct_batches.append(correct.cpu().numpy())

    # Shape (num_samples, num_classes), derived from the data instead of the
    # former hard-coded (50000, 10). Cast to float64 to keep the dtype the
    # previous np.append-based accumulation wrote to disk.
    predict_all = np.concatenate(predict_batches, axis=0).astype(np.float64)
    targets_all = np.concatenate(target_batches)
    correct_all = np.concatenate(correct_batches)

    # Two columns per sample: true target, and whether the prediction matched
    # the random label (stored as 0/1 after integer promotion).
    targets_ = np.vstack((targets_all, correct_all)).T
    np.save("preprocessing/c100_train_targ.npy", targets_)
    np.save("preprocessing/c100_train_pred.npy", predict_all)
Ejemplo n.º 3
0
        # Nesterov-momentum SGD with a fixed momentum of 0.9. This assignment
        # is indented one level deeper than the code that follows, so it
        # belongs to a branch whose condition lies above the visible region.
        optimizer = SGD(model.parameters(),
                        lr=args.learning_rate,
                        momentum=0.9,
                        nesterov=True,
                        weight_decay=args.weight_decay)

    # Learning-rate schedule; used below as a callable with an `lr()`
    # accessor, so presumably a project-local StepLR -- confirm against the
    # imports.
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)

    # Standard single-step training: one forward/backward pass and one
    # optimizer step per batch.
    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for i, (inputs, labels) in enumerate(dataset.train):
            optimizer.zero_grad()
            outputs = model(inputs.to(device))
            loss = smooth_crossentropy(outputs, labels.to(device))
            # create_graph is enabled only when the "ada_hessian" optimizer is
            # selected -- presumably because it needs second-order gradients;
            # verify against that optimizer's implementation.
            loss.mean().backward(create_graph=args.optimizer == "ada_hessian")
            optimizer.step()

            with torch.no_grad():
                # `labels` itself was never moved to the device (only a copy
                # was passed to the loss), so the outputs are brought to the
                # CPU for the comparison.
                correct = (torch.argmax(outputs.data,
                                        1).cpu() == labels).float()
                log(model, loss.cpu(), correct, scheduler.lr())
                # The scheduler is invoked once per batch with the epoch index.
                scheduler(epoch)

        # Evaluation pass over the test split, without gradient tracking.
        model.eval()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for inputs, labels in dataset.test:
                outputs = model(inputs.to(device))
Ejemplo n.º 4
0
    parser.add_argument("--width_factor",
                        default=8,
                        type=int,
                        help="How many times wider compared to normal ResNet.")
    args = parser.parse_args()

    initialize(args, seed=42)
    # Use the first GPU when available, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataset = Cifar(args.batch_size, args.threads)
    model = WideResNet(args.depth,
                       args.width_factor,
                       args.dropout,
                       in_channels=3,
                       labels=10).to(device)

    # Restore the trained weights. map_location makes a checkpoint that was
    # saved on a GPU loadable when `device` resolved to the CPU above.
    PATH = './trained_models/sam_net_250.pth'
    model.load_state_dict(torch.load(PATH, map_location=device))
    # A freshly constructed module is in training mode; switch to eval mode so
    # dropout and other train-time layer behaviour do not affect the
    # predictions printed below.
    model.eval()

    # Run the test split through the model and print, for every batch, each
    # sample's raw output vector followed by the batch's target labels.
    with torch.no_grad():
        for batch in dataset.test:
            inputs, targets = (b.to(device) for b in batch)

            predictions = model(inputs)
            # `loss` and `correct` are not used in the visible code; kept in
            # case the loop body continues past this snippet.
            loss = smooth_crossentropy(predictions, targets)
            correct = torch.argmax(predictions, 1) == targets
            for p in predictions:
                print(p)
            print()
            print(targets)
            print()
Ejemplo n.º 5
0
Archivo: train.py Proyecto: davda54/sam
    # SAM wraps a base optimizer class (plain SGD here); the remaining keyword
    # arguments are taken from the command-line options.
    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    # Learning-rate schedule; used below as a callable with an `lr()`
    # accessor, so presumably a project-local StepLR -- confirm against the
    # imports.
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)

    # One pass over dataset.train per epoch; each batch performs two
    # forward/backward passes around optimizer.first_step / second_step.
    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            # Move every tensor of the batch to the target device.
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            # `predictions` and `loss` from this pass are the values that get
            # logged further down; the second pass does not rebind them.
            # NOTE(review): enable_running_stats presumably turns running
            # statistics tracking on for this pass -- confirm in the helper.
            enable_running_stats(model)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            # NOTE(review): disable_running_stats presumably keeps this second
            # forward pass from updating running statistics again -- confirm
            # in the helper.
            disable_running_stats(model)
            smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                # Per-sample correctness mask based on the first-pass outputs.
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                # The scheduler is invoked once per batch with the epoch index.
                scheduler(epoch)

        # Switch to evaluation mode before the test pass that follows.
        model.eval()
        log.eval(len_dataset=len(dataset.test))