def process_function(unused_engine, batch):
    # `device`, `model`, `optimizer`, `config`, `criterion`, `args`, and
    # `and_mask_utils` come from the surrounding training script.
    x, y = _prepare_batch(batch, device=device, non_blocking=True)
    model.train()
    optimizer.zero_grad()
    y_pred = model(x)

    if config['agreement_threshold'] > 0.0:
        # The "batch_size" in this function refers to the batch size per env.
        # Since we treat every example as one env, we should set the parameter
        # n_agreement_envs equal to the batch size.
        mean_loss, masks = and_mask_utils.get_grads(
            agreement_threshold=config['agreement_threshold'],
            batch_size=1,
            loss_fn=criterion,
            n_agreement_envs=config['batch_size'],
            params=optimizer.param_groups[0]['params'],
            output=y_pred,
            target=y,
            method=args.method,
            scale_grad_inverse_sparsity=config['scale_grad_inverse_sparsity'],
        )
    else:
        mean_loss = criterion(y_pred, y)
        mean_loss.backward()

    optimizer.step()
    return {}
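# The (engine, batch) signature and `_prepare_batch` suggest this process
# function is meant to be driven by PyTorch Ignite. A minimal sketch of the
# wiring, assuming Ignite is the driver; `train_loader` and the
# `config['epochs']` key are hypothetical names, not taken from the original.
from ignite.engine import Engine

trainer = Engine(process_function)
# `train_loader` is assumed to be an ordinary torch DataLoader.
trainer.run(train_loader, max_epochs=config['epochs'])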
def train(model, device, train_loaders, optimizer, epoch, writer,
          scale_grad_inverse_sparsity, n_agreement_envs, loss_fn, l1_coef,
          method, agreement_threshold, scheduler, log_suffix=''):
    """n_agreement_envs is the number of envs used to compute agreements."""
    # Divisibility makes it more convenient to form groups of environments.
    assert len(train_loaders) % n_agreement_envs == 0

    model.train()

    losses = []
    correct = 0
    example_count = 0
    batch_idx = 0

    train_iterators = [iter(loader) for loader in train_loaders]
    it_groups = permutation_groups(train_iterators, n_agreement_envs)

    while True:
        train_iterator_selection = next(it_groups)
        try:
            datas = [next(iterator) for iterator in train_iterator_selection]
        except StopIteration:
            break

        assert len(datas) == n_agreement_envs

        batch_size = datas[0][0].shape[0]
        assert all(d[0].shape[0] == batch_size for d in datas)

        inputs = [d[0].to(device) for d in datas]
        target = [d[1].to(device) for d in datas]

        inputs = torch.cat(inputs, dim=0)
        target = torch.cat(target, dim=0)

        optimizer.zero_grad()

        output = model(inputs)
        output = output.squeeze(1)
        validate_target_outupt_shapes(output, target)

        mean_loss, masks = get_grads(
            agreement_threshold,
            batch_size,
            loss_fn,
            n_agreement_envs,
            params=optimizer.param_groups[0]['params'],
            output=output,
            target=target,
            method=method,
            scale_grad_inverse_sparsity=scale_grad_inverse_sparsity,
        )
        model.step += 1

        if l1_coef > 0.0:
            add_l1_grads(l1_coef, optimizer.param_groups)

        optimizer.step()

        losses.append(mean_loss.item())
        correct += count_correct(output, target)
        example_count += output.shape[0]
        batch_idx += 1

    scheduler.step()

    # Logging
    train_loss = np.mean(losses)
    train_acc = correct / (example_count + 1e-10)
    writer.add_scalar('weight/norm', train_loss, epoch)
    writer.add_scalar(f'mean_loss/train{log_suffix}', train_loss, epoch)
    writer.add_scalar(f'acc/train{log_suffix}', train_acc, epoch)
    logger.info(
        f'Train Epoch: {epoch}\t Acc: {train_acc:.4} \tLoss: {train_loss:.6f}')
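# `permutation_groups` is not defined in this section. Below is a minimal
# sketch consistent with how it is used above: an infinite generator that
# repeatedly shuffles the environment iterators and yields groups of
# `n_agreement_envs` of them, with the training loop terminating when any
# underlying iterator raises StopIteration. This is an assumed implementation,
# not necessarily the original one.
import random

def permutation_groups(items, n):
    # Same divisibility requirement as the assert in train().
    assert len(items) % n == 0
    items = list(items)
    while True:
        random.shuffle(items)
        for i in range(0, len(items), n):
            yield items[i:i + n]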
def train(model, args, device, train_loader, optimizer, epoch, writer,
          scale_grad_inverse_sparsity, loss_fn, method, agreement_threshold,
          scheduler, log_suffix=''):
    model.train()

    losses = []
    correct = 0
    example_count = 0
    batch_idx = 0

    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, args.n_dims)

        optimizer.zero_grad()
        y_pred = model(images)

        if agreement_threshold > 0.0:
            # The "batch_size" in this function refers to the batch size per env.
            # Since we treat every example as one env, we should set the parameter
            # n_agreement_envs equal to the batch size.
            mean_loss, masks = and_mask_utils.get_grads(
                agreement_threshold=agreement_threshold,
                batch_size=1,
                loss_fn=loss_fn,
                n_agreement_envs=args.batch_size,
                params=optimizer.param_groups[0]['params'],
                output=y_pred,
                target=labels,
                method=args.method,
                scale_grad_inverse_sparsity=scale_grad_inverse_sparsity,
            )
        else:
            mean_loss = loss_fn(y_pred, labels)
            mean_loss.backward()

        mean_total_loss = 0
        if args.l1_coef > 0.0:
            add_l1_grads(args.l1_coef, optimizer.param_groups)
            mean_total_loss += add_l1(args.l1_coef, optimizer.param_groups)
        if args.l2_coef > 0.0:
            add_l2_grads(args.l2_coef, optimizer.param_groups)
            mean_total_loss += add_l2(args.l2_coef, optimizer.param_groups)
        mean_total_loss += mean_loss.item()

        optimizer.step()

        losses.append(mean_total_loss)
        correct += count_correct(y_pred, labels)
        example_count += y_pred.shape[0]
        batch_idx += 1

    scheduler.step()

    # Logging
    train_loss = np.mean(losses)
    train_acc = correct / (example_count + 1e-10)
    writer.add_scalar('weight/norm', train_loss, epoch)
    writer.add_scalar(f'mean_loss/train{log_suffix}', train_loss, epoch)
    writer.add_scalar(f'acc/train{log_suffix}', train_acc, epoch)
    logger.info(
        f'Train Epoch: {epoch}\t Acc: {train_acc:.4} \tLoss: {train_loss:.6f}')
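# The regularization helpers (`add_l1_grads`, `add_l1`, and their L2
# counterparts) are not shown in this section. A minimal sketch of the L1 pair,
# assuming the `_grads` helper adds the penalty's gradient to each parameter's
# `.grad` and the plain helper returns the penalty value for logging; this is
# an assumed implementation, not the original.
import torch

def add_l1_grads(l1_coef, param_groups):
    # Gradient of l1_coef * |w| with respect to w is l1_coef * sign(w).
    for group in param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.add_(l1_coef * torch.sign(p.detach()))

def add_l1(l1_coef, param_groups):
    # Scalar value of the L1 penalty, used only for loss logging.
    with torch.no_grad():
        return sum(l1_coef * p.abs().sum().item()
                   for group in param_groups for p in group['params'])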