def train(args, model, trainloader, optimizer, epoch):
    """Run one epoch of adversarial training with the MART robust loss.

    Args:
        args: namespace providing ``step_size``, ``beta`` and ``log_interval``.
        model: network being trained; moved into train mode here.
        trainloader: iterable of (data, target) CUDA-movable batches.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (used only for progress printing).

    Note: inputs are moved with ``.cuda()``, so a CUDA device is required.
    """
    model.train()
    # Fixed perturbation budget (4/255); hoisted out of the loop since it
    # never changes. The original also assigned perturb_steps = 4 here, but
    # that value was dead — it was overwritten every batch below.
    epsilon = 4.0 / 255
    for batch_idx, (data, target) in enumerate(trainloader):
        data = data.cuda()
        target = target.cuda()
        optimizer.zero_grad()
        # Randomize the number of PGD steps per batch in [1, 9]
        # (np.random.randint's upper bound is exclusive).
        perturb_steps = np.random.randint(1, 10)
        # calculate robust loss
        loss = mart_loss(model=model,
                         x_natural=data,
                         y=target,
                         optimizer=optimizer,
                         step_size=args.step_size,
                         epsilon=epsilon,
                         perturb_steps=perturb_steps,
                         beta=args.beta)
        loss.backward()
        optimizer.step()
        # print progress
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.item()))
def train(args, model, device, train_loader, optimizer, epoch):
    """One epoch of MART adversarial training on ``device``.

    Args:
        args: namespace with ``step_size``, ``epsilon``, ``num_steps``,
            ``beta`` and ``log_interval``.
        model: network to train (switched to train mode).
        device: torch device the batches are moved onto.
        train_loader: iterable of (data, target) batches.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number, used only in the progress line.
    """
    model.train()
    # Loop-invariant sizes used by the progress printout.
    total = len(train_loader.dataset)
    num_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # calculate robust loss
        loss = mart_loss(
            model=model,
            x_natural=data,
            y=target,
            optimizer=optimizer,
            step_size=args.step_size,
            epsilon=args.epsilon,
            perturb_steps=args.num_steps,
            beta=args.beta,
        )
        loss.backward()
        optimizer.step()
        # print progress
        if batch_idx % args.log_interval == 0:
            done = batch_idx * len(data)
            pct = 100. * batch_idx / num_batches
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, done, total, pct, loss.item()))