def inference(model, itr, pop_idx, arch, val_loader, use_gpu, args):
    model.eval()
    all_prec1 = AverageMeter()
    all_prec5 = AverageMeter()
    with torch.no_grad():
        for batch_idx, (img, label) in enumerate(val_loader):
            if use_gpu:
                img, label = img.cuda(), label.cuda()
            output = model(img, arch)
            prec1, prec5 = accuracy(output, label, topk=(1, 5))
            if args.distributed:
                prec1 = reduce_tensor(prec1, args.world_size)
                prec5 = reduce_tensor(prec5, args.world_size)
            all_prec1.update(prec1.item(), img.size(0))
            all_prec5.update(prec5.item(), img.size(0))
    flops = get_flops(arch, args.flop_table) / 1e6
    if args.local_rank == 0:
        logging.info('Iter: [{}/{}][{}/{}]\t'
                     'Arch: {}\t'
                     'FLOPs: {:.2f}M\t'
                     'Prec@1: {:.2f}%\t'
                     'Prec@5: {:.2f}%'.format(itr, args.total_search_iters, pop_idx + 1,
                                              args.pop_size, arch, flops,
                                              all_prec1.avg, all_prec5.avg))
    return all_prec1.avg, all_prec5.avg, flops
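# A minimal sketch of the reduce_tensor helper that the snippets in this section
# call but never define. The name and call sites come from the code above; the body
# below is an assumption, not any of the original implementations. Different snippets
# pass different second arguments (a world size, a literal 1, an args object, or
# nothing), so each repo keeps its own variant; the common core is an all-reduce
# followed by an optional division to turn the per-rank metric into a cross-rank average.
import torch
import torch.distributed as dist

def reduce_tensor(tensor, world_size):
    # sum the tensor across all ranks, then divide so every rank holds the average
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt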
def sync(self):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    val = torch.tensor(self.val).cuda()
    sum_v = torch.tensor(self.sum).cuda()
    count = torch.tensor(self.count).cuda()
    self.val = reduce_tensor(val, world_size).item()
    self.sum = reduce_tensor(sum_v, 1).item()
    self.count = reduce_tensor(count, 1).item()
    self.avg = self.sum / max(1, self.count)
def _forward(self, x, logpx=None):
    num_channels = x.size(-1)
    used_mean = self.running_mean.clone().detach()
    used_var = self.running_var.clone().detach()

    if self.training:
        # compute batch statistics
        x_t = x.transpose(0, 1).reshape(num_channels, -1)
        batch_mean = torch.mean(x_t, dim=1)
        if self.sync:
            batch_ex2 = torch.mean(x_t**2, dim=1)
            batch_mean = reduce_tensor(batch_mean)
            batch_ex2 = reduce_tensor(batch_ex2)
            batch_var = batch_ex2 - batch_mean**2
        else:
            batch_var = torch.var(x_t, dim=1)

        # moving average
        if self.bn_lag > 0:
            used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
            used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
            used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
            used_var /= (1. - self.bn_lag**(self.step[0] + 1))

        # update running estimates
        self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
        self.running_var -= self.decay * (self.running_var - batch_var.data)
        self.step += 1

    # perform normalization
    used_mean = used_mean.view(*self.shape).expand_as(x)
    used_var = used_var.view(*self.shape).expand_as(x)

    y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))

    if self.affine:
        weight = self.weight.view(*self.shape).expand_as(x)
        bias = self.bias.view(*self.shape).expand_as(x)
        y = y * torch.exp(weight) + bias

    if logpx is None:
        return y
    else:
        return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)
def ens_validate(val_loader, model, criterion, args, log, num_mc_samples=20, suffix=''):
    model.eval()
    ece_func = _ECELoss().cuda(args.gpu)
    with torch.no_grad():
        targets = []
        mis = [0 for _ in range(len(val_loader))]
        preds = [0 for _ in range(len(val_loader))]
        rets = torch.zeros(num_mc_samples, 9).cuda(args.gpu)
        for i, (input, target) in enumerate(val_loader):
            input = input.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            targets.append(target)

            for ens in range(num_mc_samples):
                output = model(input)
                one_loss = criterion(output, target)
                one_prec1, one_prec5 = accuracy(output, target, topk=(1, 5))

                mis[i] = (mis[i] * ens + (-output.softmax(-1) * output.log_softmax(-1)).sum(1)) / (ens + 1)
                preds[i] = (preds[i] * ens + output.softmax(-1)) / (ens + 1)

                loss = criterion(preds[i].log(), target)
                prec1, prec5 = accuracy(preds[i], target, topk=(1, 5))

                rets[ens, 0] += ens * target.size(0)
                rets[ens, 1] += one_loss.item() * target.size(0)
                rets[ens, 2] += one_prec1.item() * target.size(0)
                rets[ens, 3] += one_prec5.item() * target.size(0)
                rets[ens, 5] += loss.item() * target.size(0)
                rets[ens, 6] += prec1.item() * target.size(0)
                rets[ens, 7] += prec5.item() * target.size(0)

        preds = torch.cat(preds, 0)
        # to sync
        confidences, predictions = torch.max(preds, 1)
        targets = torch.cat(targets, 0)
        mis = (-preds * preds.log()).sum(1) - torch.cat(mis, 0)

        rets /= targets.size(0)
        if args.distributed:
            if suffix == '':
                confidences = dist_collect(confidences)
                predictions = dist_collect(predictions)
                targets = dist_collect(targets)
                mis = dist_collect(mis)
            rets = reduce_tensor(rets.data, args)
        rets = rets.data.cpu().numpy()
        if suffix == '':
            ens_ece = ece_func(confidences, predictions, targets,
                               os.path.join(args.save_path, 'ens_cal{}.pdf'.format(suffix)))
            rets[-1, -1] = ens_ece

    if args.gpu == 0:
        np.save(os.path.join(args.save_path, 'mis{}.npy'.format(suffix)), mis.data.cpu().numpy())
    return rets
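# dist_collect, used above to gather per-rank confidences, predictions, and targets
# before the ECE computation, is not defined in this section. A plausible sketch
# (an assumption, not the original implementation): all_gather the tensor from every
# rank and concatenate along the batch dimension so each process sees the full set.
import torch
import torch.distributed as dist

def dist_collect(tensor):
    # gather one copy of the tensor per rank, then concatenate them
    gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor.contiguous())
    return torch.cat(gathered, dim=0)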
def validate_single_class(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
def train_epoch(model, epoch, optim_tools, train_loader, use_gpu):
    model.train()
    optimizer, criterion, scheduler = optim_tools
    losses, train_time, data_time = [AverageMeter() for _ in range(3)]
    st_time = time.time()
    for batch_idx, (img, label) in enumerate(train_loader):
        data_time.update(time.time() - st_time)
        if use_gpu:
            img, label = img.cuda(), label.cuda()

        arch, flops = uniform_constraint_sampling(sum(args.num_layer_list),
                                                  args.num_block_type,
                                                  args.flop_table,
                                                  args.local_rank)
        output = model(img, arch)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if not args.distributed:
            losses.update(loss.item(), img.size(0))
        if use_gpu:
            torch.cuda.synchronize()
        train_time.update(time.time() - st_time)

        if batch_idx == 0 or (batch_idx + 1) % args.disp_freq == 0 or batch_idx + 1 == len(train_loader):
            if args.distributed:
                reduced_loss = reduce_tensor(loss.detach(), args.world_size)
                losses.update(reduced_loss.item(), img.size(0))
            if args.local_rank == 0:
                lr = scheduler.get_lr()[0]
                logging.info(
                    'Epoch: [{}/{}][{}/{}]\t'
                    'LR: {:.2e}\t'
                    'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Train time: {train_time.val:.4f}s ({train_time.avg:.4f}s)\t'
                    'Load data time: {data_time.val:.4f}s ({data_time.avg:.4f}s)'.format(
                        epoch, args.total_epochs, batch_idx + 1, len(train_loader), lr,
                        loss=losses, train_time=train_time, data_time=data_time))
        st_time = time.time()
def validate(model, loader, use_gpu):
    model.eval()
    all_prec1, all_prec5, val_time = [AverageMeter() for _ in range(3)]
    st_time = time.time()
    with torch.no_grad():
        for batch_idx, (img, label) in enumerate(loader):
            if use_gpu:
                img, label = img.cuda(), label.cuda()
            output = model(img)
            prec1, prec5 = accuracy(output, label, topk=(1, 5))
            if args.distributed:
                prec1 = reduce_tensor(prec1, args.world_size)
                prec5 = reduce_tensor(prec5, args.world_size)
            all_prec1.update(prec1.item(), img.size(0))
            all_prec5.update(prec5.item(), img.size(0))
            if use_gpu:
                torch.cuda.synchronize()
            val_time.update(time.time() - st_time)

            if args.local_rank == 0 and \
                    (batch_idx == 0 or (batch_idx + 1) % args.disp_freq == 0 or batch_idx + 1 == len(loader)):
                logging.info('Iter: [{}/{}]\t'
                             'Val time: {:.4f}s\t'
                             'Prec@1: {:.2f}%\t'
                             'Prec@5: {:.2f}%'.format(batch_idx + 1, len(loader), val_time.avg,
                                                      all_prec1.avg, all_prec5.avg))
            st_time = time.time()
    return all_prec1.avg, all_prec5.avg
def sync(self):
    buf = torch.tensor([self._sum, self._count], dtype=torch.float32).cuda()
    buf = reduce_tensor(buf, 1)
    _sum, _count = buf.tolist()
    _avg = _sum / max(1, _count)

    r = self._history_count / max(1, self._history_count + _count)
    self._history_avg = r * self._history_avg + (1.0 - r) * _avg
    self._history_count += _count

    self._sum = 0
    self._count = 0
    self._avg = None
def _step(self, inputs, input_sizes, targets):
    """
    Make a single gradient update. This is called by train() and should not
    be called manually.

    Parameters
    ----------
    inputs:
    input_sizes:
    targets:
    """
    output = self.model(inputs, input_sizes)
    loss = self.criterion(output, targets.long())
    loss = loss / inputs.size(0)  # average the loss by minibatch
    if self.distributed:
        loss = loss.to(self.device)
        loss_value = reduce_tensor(loss, self.world_size).item()
    else:
        loss_value = loss.item()

    # Check to ensure valid loss was calculated
    valid_loss, error = check_loss(loss, loss_value)
    if valid_loss:
        self.optimizer.zero_grad()
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.max_norm)
        self.optimizer.step()
    else:
        print(error)
        print('Skipping grad update')
        loss_value = 0
    return output, loss_value
def step(self, loss):
    if self.distributed:
        loss = loss.to(self.device)
        loss_value = reduce_tensor(loss, self.world_size).item()
    else:
        loss_value = loss.item()

    valid_loss, error = check_loss(loss, loss_value)
    if valid_loss:
        self.optimizer.zero_grad()
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.max_norm)
        self.optimizer.step()
    else:
        print(error)
        print('Skipping grad update')
        return False
    self.avg_loss += loss_value
    return True
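# check_loss(loss, loss_value), used by the two step methods above and the training
# fragments below, is assumed to be a small validity guard that rejects degenerate
# losses before the backward pass. The signature comes from the call sites; the body
# is a hedged sketch, not the project's actual implementation.
import torch

def check_loss(loss, loss_value):
    # reject inf/nan/negative losses so one bad batch does not corrupt the weights
    valid_loss, error = True, ''
    if loss_value == float('inf') or loss_value == float('-inf'):
        valid_loss, error = False, 'WARNING: received an inf loss'
    elif torch.isnan(loss).sum() > 0:
        valid_loss, error = False, 'WARNING: received a nan loss'
    elif loss_value < 0:
        valid_loss, error = False, 'WARNING: received a negative loss'
    return valid_loss, error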
def predict(model, path):
    with open(os.path.join(Path('data'), "dev.pkl"), "rb") as fin:
        x_dev, y_dev = pickle.load(fin)
    dev_examples = predict_processor.get_test_examples(path, x_dev, x_dev, size=-1)
    # print("Number of test examples: {}".format(len(dev_examples)))
    # print("device: {}".format(device))
    test_features = convert_examples_to_features(dev_examples, label_list,
                                                 args.max_seq_length, tokenizer)
    logger.info("***** Running prediction *****")
    logger.info(" Num examples = %d", len(dev_examples))
    logger.info(" Batch size = %d", args.eval_batch_size)

    all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in test_features], dtype=torch.long)

    test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    # Run prediction for full data
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)

    model.eval()
    pre_loss, pre_accuracy = 0, 0
    nb_pre_steps, nb_pre_examples = 0, 0
    for step, batch in enumerate(tqdm(test_dataloader, desc="Prediction Iteration")):
        input_ids, input_mask, segment_ids, label_ids = batch
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)

        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)

        loss_fct = CrossEntropyLoss().to(device)
        tmp_pre_loss = loss_fct(logits.view(-1, num_labels), label_ids.squeeze())
        tmp_pre_accuracy = accuracy(logits.view(-1, num_labels).detach().cpu().numpy(),
                                    label_ids.squeeze().detach().cpu().numpy())

        if args.local_rank != -1:
            tmp_pre_loss = reduce_tensor(tmp_pre_loss)
            tmp_pre_accuracy = reduce_tensor(torch.tensor(tmp_pre_accuracy).to(device))

        pre_loss += tmp_pre_loss.mean().item()
        pre_accuracy += tmp_pre_accuracy.item()

        nb_pre_examples += input_ids.size(0)
        nb_pre_steps += 1

    pre_loss = pre_loss / nb_pre_steps
    pre_accuracy = pre_accuracy / nb_pre_examples
    result = {'pre_loss': pre_loss, 'pre_accuracy': pre_accuracy}

    output_pre_file = os.path.join(args.output_dir, "pre_results.txt")
    with open(output_pre_file, "w") as writer:
        logger.info("***** Pre results *****")
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
    return pre_loss, pre_accuracy
input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
# measure data loading time
data_time.update(time.time() - end)
inputs = inputs.to(device)

out, output_sizes = model(inputs, input_sizes)
out = out.transpose(0, 1)  # TxNxH

float_out = out.float()  # ensure float32 for loss
loss = criterion(float_out, targets, output_sizes, target_sizes).to(device)
loss = loss / inputs.size(0)  # average the loss by minibatch
if args.distributed:
    loss = loss.to(device)
    loss_value = reduce_tensor(loss, args.world_size).item()
    data_time = reduce_tensor(data_time, args.world_size, reduce_op_max=True)
else:
    loss_value = loss.item()

# Check to ensure valid loss was calculated
valid_loss, error = check_loss(loss, loss_value)
if valid_loss:
    optimizer.zero_grad()

    # compute gradient
    if args.mixed_precision:
        optimizer.backward(loss)
        optimizer.clip_master_grads(args.max_norm)
    else:
def train(self):
    data_iter = iter(self.train_dataloader)

    if self.train_config.resume_checkpoint:
        start = self.resume_step + 1
    else:
        start = 0

    moving_max_grad = 0
    moving_grad_moment = 0.999
    max_grad = 0

    for step in range(start, self.train_config.total_step + 1):
        try:
            image_dict = next(data_iter)
        except:
            data_iter = iter(self.train_dataloader)
            image_dict = next(data_iter)

        image, alpha, trimap, mask = image_dict['image'], image_dict['alpha'], image_dict['trimap'], image_dict['mask']
        image = image.cuda()
        alpha = alpha.cuda()
        trimap = trimap.cuda()
        mask = mask.cuda()
        fg_norm, bg_norm = image_dict['fg'].cuda(), image_dict['bg'].cuda()

        # train() of DistributedDataParallel has no return
        self.G.train()
        log_info = ""
        loss = 0

        """===== Update Learning Rate ====="""
        if step < self.train_config.warmup_step and self.train_config.resume_checkpoint is None:
            cur_G_lr = utils.warmup_lr(self.train_config.G_lr, step + 1, self.train_config.warmup_step)
            utils.update_lr(cur_G_lr, self.G_optimizer)
        else:
            self.G_scheduler.step()
            cur_G_lr = self.G_scheduler.get_lr()[0]

        """===== Forward G ====="""
        pred = self.G(image, mask)
        alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8']

        weight_os8 = utils.get_unknown_tensor(trimap)
        weight_os8[...] = 1

        flag = False
        if step < self.train_config.warmup_step:
            flag = True
            weight_os4 = utils.get_unknown_tensor(trimap)
            weight_os1 = utils.get_unknown_tensor(trimap)
        elif step < self.train_config.warmup_step * 3:
            if random.randint(0, 1) == 0:
                flag = True
                weight_os4 = utils.get_unknown_tensor(trimap)
                weight_os1 = utils.get_unknown_tensor(trimap)
            else:
                weight_os4 = utils.get_unknown_tensor_from_pred(
                    alpha_pred_os8, rand_width=CONFIG.model.self_refine_width1, train_mode=True)
                alpha_pred_os4[weight_os4 == 0] = alpha_pred_os8[weight_os4 == 0]
                weight_os1 = utils.get_unknown_tensor_from_pred(
                    alpha_pred_os4, rand_width=CONFIG.model.self_refine_width2, train_mode=True)
                alpha_pred_os1[weight_os1 == 0] = alpha_pred_os4[weight_os1 == 0]
        else:
            weight_os4 = utils.get_unknown_tensor_from_pred(
                alpha_pred_os8, rand_width=CONFIG.model.self_refine_width1, train_mode=True)
            alpha_pred_os4[weight_os4 == 0] = alpha_pred_os8[weight_os4 == 0]
            weight_os1 = utils.get_unknown_tensor_from_pred(
                alpha_pred_os4, rand_width=CONFIG.model.self_refine_width2, train_mode=True)
            alpha_pred_os1[weight_os1 == 0] = alpha_pred_os4[weight_os1 == 0]

        """===== Calculate Loss ====="""
        if self.train_config.rec_weight > 0:
            self.loss_dict['rec'] = (self.regression_loss(alpha_pred_os1, alpha, loss_type='l1', weight=weight_os1) * 2 +
                                     self.regression_loss(alpha_pred_os4, alpha, loss_type='l1', weight=weight_os4) * 1 +
                                     self.regression_loss(alpha_pred_os8, alpha, loss_type='l1', weight=weight_os8) * 1) / 5.0 * self.train_config.rec_weight

        if self.train_config.comp_weight > 0:
            self.loss_dict['comp'] = (self.composition_loss(alpha_pred_os1, fg_norm, bg_norm, image, weight=weight_os1) * 2 +
                                      self.composition_loss(alpha_pred_os4, fg_norm, bg_norm, image, weight=weight_os4) * 1 +
                                      self.composition_loss(alpha_pred_os8, fg_norm, bg_norm, image, weight=weight_os8) * 1) / 5.0 * self.train_config.comp_weight

        if self.train_config.lap_weight > 0:
            self.loss_dict['lap'] = (self.lap_loss(logit=alpha_pred_os1, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os1) * 2 +
                                     self.lap_loss(logit=alpha_pred_os4, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os4) * 1 +
                                     self.lap_loss(logit=alpha_pred_os8, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os8) * 1) / 5.0 * self.train_config.lap_weight

        for loss_key in self.loss_dict.keys():
            if self.loss_dict[loss_key] is not None and loss_key in ['rec', 'comp', 'lap']:
                loss += self.loss_dict[loss_key]

        """===== Back Propagate ====="""
        self.reset_grad()
        loss.backward()

        """===== Clip Large Gradient ====="""
        if self.train_config.clip_grad:
            if moving_max_grad == 0:
                moving_max_grad = nn_utils.clip_grad_norm_(self.G.parameters(), 1e+6)
                max_grad = moving_max_grad
            else:
                max_grad = nn_utils.clip_grad_norm_(self.G.parameters(), 2 * moving_max_grad)
                moving_max_grad = moving_max_grad * moving_grad_moment + max_grad * (1 - moving_grad_moment)

        """===== Update Parameters ====="""
        self.G_optimizer.step()

        """===== Write Log and Tensorboard ====="""
        # stdout log
        if step % self.log_config.logging_step == 0:
            # reduce losses from GPUs
            if CONFIG.dist:
                self.loss_dict = utils.reduce_tensor_dict(self.loss_dict, mode='mean')
                loss = utils.reduce_tensor(loss)

            # create logging information
            for loss_key in self.loss_dict.keys():
                if self.loss_dict[loss_key] is not None:
                    log_info += loss_key.upper() + ": {:.4f}, ".format(self.loss_dict[loss_key])

            self.logger.debug("Image tensor shape: {}. Trimap tensor shape: {}".format(image.shape, trimap.shape))
            log_info = "[{}/{}], ".format(step, self.train_config.total_step) + log_info
            log_info += "lr: {:6f}".format(cur_G_lr)
            self.logger.info(log_info)

        # tensorboard
        if step % self.log_config.tensorboard_step == 0 or step == start:  # and step > start:
            self.tb_logger.scalar_summary('Loss', loss, step)
            # detailed losses
            for loss_key in self.loss_dict.keys():
                if self.loss_dict[loss_key] is not None:
                    self.tb_logger.scalar_summary('Loss_' + loss_key.upper(), self.loss_dict[loss_key], step)
            self.tb_logger.scalar_summary('LearnRate', cur_G_lr, step)
            if self.train_config.clip_grad:
                self.tb_logger.scalar_summary('Moving_Max_Grad', moving_max_grad, step)
                self.tb_logger.scalar_summary('Max_Grad', max_grad, step)

        """===== TEST ====="""
        if ((step % self.train_config.val_step) == 0 or step == self.train_config.total_step):  # and step > start:
            self.G.eval()
            test_loss = 0
            log_info = ""

            self.test_loss_dict['mse'] = 0
            self.test_loss_dict['sad'] = 0
            for loss_key in self.loss_dict.keys():
                if loss_key in self.test_loss_dict and self.loss_dict[loss_key] is not None:
                    self.test_loss_dict[loss_key] = 0

            with torch.no_grad():
                for image_dict in self.test_dataloader:
                    image, alpha, trimap, mask = image_dict['image'], image_dict['alpha'], image_dict['trimap'], image_dict['mask']
                    alpha_shape = image_dict['alpha_shape']
                    image = image.cuda()
                    alpha = alpha.cuda()
                    trimap = trimap.cuda()
                    mask = mask.cuda()

                    pred = self.G(image, mask)
                    alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8']

                    alpha_pred = alpha_pred_os8.clone().detach()
                    weight_os4 = utils.get_unknown_tensor_from_pred(
                        alpha_pred, rand_width=CONFIG.model.self_refine_width1, train_mode=False)
                    alpha_pred[weight_os4 > 0] = alpha_pred_os4[weight_os4 > 0]
                    weight_os1 = utils.get_unknown_tensor_from_pred(
                        alpha_pred, rand_width=CONFIG.model.self_refine_width2, train_mode=False)
                    alpha_pred[weight_os1 > 0] = alpha_pred_os1[weight_os1 > 0]

                    h, w = alpha_shape
                    alpha_pred = alpha_pred[..., :h, :w]
                    trimap = trimap[..., :h, :w]

                    weight = utils.get_unknown_tensor(trimap)
                    weight[...] = 1  # value of MSE/SAD here is different from test.py and matlab version

                    self.test_loss_dict['mse'] += self.mse(alpha_pred, alpha, weight)
                    self.test_loss_dict['sad'] += self.sad(alpha_pred, alpha, weight)

                    if self.train_config.rec_weight > 0:
                        self.test_loss_dict['rec'] += self.regression_loss(alpha_pred, alpha, weight=weight) \
                                                      * self.train_config.rec_weight

            # reduce losses from GPUs
            if CONFIG.dist:
                self.test_loss_dict = utils.reduce_tensor_dict(self.test_loss_dict, mode='mean')

            """===== Write Log and Tensorboard ====="""
            # stdout log
            for loss_key in self.test_loss_dict.keys():
                if self.test_loss_dict[loss_key] is not None:
                    self.test_loss_dict[loss_key] /= len(self.test_dataloader)
                    # logging
                    log_info += loss_key.upper() + ": {:.4f} ".format(self.test_loss_dict[loss_key])
                    self.tb_logger.scalar_summary('Loss_' + loss_key.upper(),
                                                  self.test_loss_dict[loss_key], step, phase='test')
                    if loss_key in ['rec']:
                        test_loss += self.test_loss_dict[loss_key]

            self.logger.info("TEST: LOSS: {:.4f} ".format(test_loss) + log_info)
            self.tb_logger.scalar_summary('Loss', test_loss, step, phase='test')

            # if self.model_config.trimap_channel == 3:
            #     trimap = trimap.argmax(dim=1, keepdim=True)
            # alpha_pred[trimap==2] = 1
            # alpha_pred[trimap==0] = 0
            image_set = {
                'image': (utils.normalize_image(image[-1, ...]).data.cpu().numpy() * 255).astype(np.uint8),
                'mask': (mask[-1, ...].data.cpu().numpy() * 255).astype(np.uint8),
                'alpha': (alpha[-1, ...].data.cpu().numpy() * 255).astype(np.uint8),
                'alpha_pred': (alpha_pred[-1, ...].data.cpu().numpy() * 255).astype(np.uint8)
            }
            self.tb_logger.image_summary(image_set, step, phase='test')

        """===== Save Model ====="""
        if (step % self.log_config.checkpoint_step == 0 or step == self.train_config.total_step) \
                and CONFIG.local_rank == 0 and (step > start):
            self.logger.info('Saving the trained models from step {}...'.format(step))
            self.save_model("latest_model", step, loss)
            if self.test_loss_dict['mse'] < self.best_loss:
                self.best_loss = self.test_loss_dict['mse']
                self.save_model("best_model", step, loss)

        torch.cuda.empty_cache()
def finetune(args, train_loader, test_loader, model, criterion):
    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
    labeled_loader = DataLoader(train_loader.dataset,
                                sampler=train_sampler(train_loader.dataset),
                                batch_size=args.finetune_batch_size,
                                num_workers=args.workers,
                                pin_memory=True)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.finetune_lr,
                          momentum=args.finetune_momentum,
                          weight_decay=args.finetune_weight_decay)
    scaler = amp.GradScaler(enabled=args.amp)

    logger.info("***** Running Finetuning *****")
    logger.info(f" Finetuning steps = {len(labeled_loader)*args.finetune_epochs}")

    for epoch in range(args.finetune_epochs):
        if args.world_size > 1:
            labeled_loader.sampler.set_epoch(epoch + 624)

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()

        model.train()
        end = time.time()
        labeled_iter = tqdm(labeled_loader, disable=args.local_rank not in [-1, 0])
        for step, (images, targets) in enumerate(labeled_iter):
            data_time.update(time.time() - end)
            batch_size = targets.shape[0]
            images = images.to(args.device)
            targets = targets.to(args.device)
            with amp.autocast(enabled=args.amp):
                model.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if args.world_size > 1:
                loss = reduce_tensor(loss.detach(), args.world_size)
            losses.update(loss.item(), batch_size)
            batch_time.update(time.time() - end)
            labeled_iter.set_description(
                f"Finetune Epoch: {epoch+1:2}/{args.finetune_epochs:2}. Data: {data_time.avg:.2f}s. "
                f"Batch: {batch_time.avg:.2f}s. Loss: {losses.avg:.4f}. ")
        labeled_iter.close()

        if args.local_rank in [-1, 0]:
            args.writer.add_scalar("finetune/train_loss", losses.avg, epoch)
            test_loss, top1, top5 = evaluate(args, test_loader, model, criterion)
            args.writer.add_scalar("finetune/test_loss", test_loss, epoch)
            args.writer.add_scalar("finetune/acc@1", top1, epoch)
            args.writer.add_scalar("finetune/acc@5", top5, epoch)

            is_best = top1 > args.best_top1
            if is_best:
                args.best_top1 = top1
                args.best_top5 = top5

            logger.info(f"top-1 acc: {top1:.2f}")
            logger.info(f"Best top-1 acc: {args.best_top1:.2f}")

            save_checkpoint(args, {
                'step': step + 1,
                'best_top1': args.best_top1,
                'best_top5': args.best_top5,
                'student_state_dict': model.state_dict(),
                'avg_state_dict': None,
                'student_optimizer': optimizer.state_dict(),
            }, is_best, finetune=True)
    return
def forward(self, x, x_noisy, std_in, opt, step=None, writer=None, init=False, valid=False):
    opt.zero_grad()
    batch_size = x.size(0)
    num_points = x.size(1)
    z_mu, z_sigma = self.encoder(x)
    if self.use_deterministic_encoder:
        z = z_mu + 0 * z_sigma
    else:
        z = self.reparameterize_gaussian(z_mu, z_sigma)

    # Compute H[Q(z|X)]
    if self.use_deterministic_encoder:
        entropy = torch.zeros(batch_size).to(z)
    else:
        entropy = self.gaussian_entropy(z_sigma)

    # Compute the prior probability P(z)
    w, delta_log_pw = self.latent_glow(z)
    log_pw = standard_normal_logprob(w).view(batch_size, -1).sum(1, keepdim=True)
    delta_log_pw = delta_log_pw.view(batch_size, 1)
    log_pz = log_pw - delta_log_pw

    # Compute the reconstruction likelihood P(X|z)
    z_new = z.view(*z.size())
    z_new = z_new + (log_pz * 0.).mean()
    y, delta_log_py = self.point_AF(x_noisy, std_in, z_new)
    log_py = standard_normal_logprob(y).view(batch_size, -1).sum(1, keepdim=True)
    delta_log_py = delta_log_py.view(batch_size, num_points, 1).sum(1)
    log_px = log_py - delta_log_py

    # Loss
    entropy_loss = -entropy.mean()
    recon_loss = -log_px.mean()
    prior_loss = -log_pz.mean()
    loss = entropy_loss + prior_loss + recon_loss

    if not init and not valid:
        loss.backward()
        opt.step()

    # LOGGING (after the training)
    if self.distributed:
        loss = reduce_tensor(loss.mean())
        entropy_log = reduce_tensor(entropy.mean())
        recon = reduce_tensor(-log_px.mean())
        prior = reduce_tensor(-log_pz.mean())
    else:
        loss = loss.mean()
        entropy_log = entropy.mean()
        recon = -log_px.mean()
        prior = -log_pz.mean()

    recon_nats = recon / float(x.size(1) * x.size(2))
    prior_nats = prior / float(self.zdim)

    if writer is not None and not valid:
        writer.add_scalar('train/entropy', entropy_log, step)
        writer.add_scalar('train/prior', prior, step)
        writer.add_scalar('train/prior(nats)', prior_nats, step)
        writer.add_scalar('train/recon', recon, step)
        writer.add_scalar('train/recon(nats)', recon_nats, step)
        writer.add_scalar('train/loss', loss.item(), step)

    return {
        'entropy': entropy_log.cpu().detach().item() if not isinstance(entropy_log, float) else entropy_log,
        'prior_nats': prior_nats,
        'recon_nats': recon_nats,
        'prior': prior,
        'recon': recon,
        'loss': loss.item()
    }
def train_loop(args, labeled_loader, unlabeled_loader, test_loader,
               teacher_model, student_model, avg_student_model, criterion,
               t_optimizer, s_optimizer, t_scheduler, s_scheduler, t_scaler, s_scaler):
    logger.info("***** Running Training *****")
    logger.info(f" Task = {args.dataset}@{args.num_labeled}")
    logger.info(f" Total steps = {args.total_steps}")

    if args.world_size > 1:
        labeled_epoch = 0
        unlabeled_epoch = 0
        labeled_loader.sampler.set_epoch(labeled_epoch)
        unlabeled_loader.sampler.set_epoch(unlabeled_epoch)

    labeled_iter = iter(labeled_loader)
    unlabeled_iter = iter(unlabeled_loader)

    moving_dot_product = torch.empty(1).to(args.device)
    limit = 3.0**(0.5)  # 3 = 6 / (f_in + f_out)
    nn.init.uniform_(moving_dot_product, -limit, limit)

    for step in range(args.start_step, args.total_steps):
        if step % args.eval_step == 0:
            pbar = tqdm(range(args.eval_step), disable=args.local_rank not in [-1, 0])
            batch_time = AverageMeter()
            data_time = AverageMeter()
            s_losses = AverageMeter()
            t_losses = AverageMeter()
            t_losses_l = AverageMeter()
            t_losses_u = AverageMeter()
            t_losses_mpl = AverageMeter()
            mean_mask = AverageMeter()

        teacher_model.train()
        student_model.train()
        end = time.time()

        try:
            images_l, targets = labeled_iter.next()
        except:
            if args.world_size > 1:
                labeled_epoch += 1
                labeled_loader.sampler.set_epoch(labeled_epoch)
            labeled_iter = iter(labeled_loader)
            images_l, targets = labeled_iter.next()

        try:
            (images_uw, images_us), _ = unlabeled_iter.next()
        except:
            if args.world_size > 1:
                unlabeled_epoch += 1
                unlabeled_loader.sampler.set_epoch(unlabeled_epoch)
            unlabeled_iter = iter(unlabeled_loader)
            (images_uw, images_us), _ = unlabeled_iter.next()

        data_time.update(time.time() - end)

        images_l = images_l.to(args.device)
        images_uw = images_uw.to(args.device)
        images_us = images_us.to(args.device)
        targets = targets.to(args.device)
        with amp.autocast(enabled=args.amp):
            batch_size = images_l.shape[0]
            t_images = torch.cat((images_l, images_uw, images_us))
            t_logits = teacher_model(t_images)
            t_logits_l = t_logits[:batch_size]
            t_logits_uw, t_logits_us = t_logits[batch_size:].chunk(2)
            del t_logits

            t_loss_l = criterion(t_logits_l, targets)

            soft_pseudo_label = torch.softmax(t_logits_uw.detach() / args.temperature, dim=-1)
            max_probs, hard_pseudo_label = torch.max(soft_pseudo_label, dim=-1)
            mask = max_probs.ge(args.threshold).float()
            t_loss_u = torch.mean(
                -(soft_pseudo_label * torch.log_softmax(t_logits_us, dim=-1)).sum(dim=-1) * mask)
            weight_u = args.lambda_u * min(1., (step + 1) / args.uda_steps)
            t_loss_uda = t_loss_l + weight_u * t_loss_u

            # s_images = torch.cat((images_l, images_us))
            # s_logits = student_model(s_images)
            # s_logits_l = s_logits[:batch_size]
            # s_logits_us = s_logits[batch_size:]
            s_logits_us = student_model(images_us)
            student_model.eval()
            with torch.no_grad():
                s_logits_l = student_model(images_l)
            student_model.train()
            # del s_logits

            s_loss_l_old = F.cross_entropy(s_logits_l.detach(), targets)
            s_loss = criterion(s_logits_us, hard_pseudo_label)

        s_scaler.scale(s_loss).backward()
        if args.grad_clip > 0:
            s_scaler.unscale_(s_optimizer)
            nn.utils.clip_grad_norm_(student_model.parameters(), args.grad_clip)
        s_scaler.step(s_optimizer)
        s_scaler.update()
        s_scheduler.step()
        if args.ema > 0:
            avg_student_model.update_parameters(student_model)

        with amp.autocast(enabled=args.amp):
            student_model.eval()
            with torch.no_grad():
                s_logits_l = student_model(images_l)
            student_model.train()
            s_loss_l_new = F.cross_entropy(s_logits_l.detach(), targets)
            dot_product = s_loss_l_new - s_loss_l_old
            # test
            # dot_product = s_loss_l_old - s_loss_l_new
            moving_dot_product = moving_dot_product * 0.99 + dot_product * 0.01
            dot_product = dot_product - moving_dot_product
            _, hard_pseudo_label = torch.max(t_logits_us.detach(), dim=-1)
            t_loss_mpl = dot_product * F.cross_entropy(t_logits_us, hard_pseudo_label)
            t_loss = t_loss_uda + t_loss_mpl

        t_scaler.scale(t_loss).backward()
        if args.grad_clip > 0:
            t_scaler.unscale_(t_optimizer)
            nn.utils.clip_grad_norm_(teacher_model.parameters(), args.grad_clip)
        t_scaler.step(t_optimizer)
        t_scaler.update()
        t_scheduler.step()

        teacher_model.zero_grad()
        student_model.zero_grad()

        if args.world_size > 1:
            s_loss = reduce_tensor(s_loss.detach(), args.world_size)
            t_loss = reduce_tensor(t_loss.detach(), args.world_size)
            t_loss_l = reduce_tensor(t_loss_l.detach(), args.world_size)
            t_loss_u = reduce_tensor(t_loss_u.detach(), args.world_size)
            t_loss_mpl = reduce_tensor(t_loss_mpl.detach(), args.world_size)
            mask = reduce_tensor(mask, args.world_size)

        s_losses.update(s_loss.item())
        t_losses.update(t_loss.item())
        t_losses_l.update(t_loss_l.item())
        t_losses_u.update(t_loss_u.item())
        t_losses_mpl.update(t_loss_mpl.item())
        mean_mask.update(mask.mean().item())

        batch_time.update(time.time() - end)
        pbar.set_description(
            f"Train Iter: {step+1:3}/{args.total_steps:3}. "
            f"LR: {get_lr(s_optimizer):.4f}. Data: {data_time.avg:.2f}s. "
            f"Batch: {batch_time.avg:.2f}s. S_Loss: {s_losses.avg:.4f}. "
            f"T_Loss: {t_losses.avg:.4f}. Mask: {mean_mask.avg:.4f}. ")
        pbar.update()
        if args.local_rank in [-1, 0]:
            args.writer.add_scalar("lr", get_lr(s_optimizer), step)

        args.num_eval = step // args.eval_step
        if (step + 1) % args.eval_step == 0:
            pbar.close()
            if args.local_rank in [-1, 0]:
                args.writer.add_scalar("train/1.s_loss", s_losses.avg, args.num_eval)
                args.writer.add_scalar("train/2.t_loss", t_losses.avg, args.num_eval)
                args.writer.add_scalar("train/3.t_labeled", t_losses_l.avg, args.num_eval)
                args.writer.add_scalar("train/4.t_unlabeled", t_losses_u.avg, args.num_eval)
                args.writer.add_scalar("train/5.t_mpl", t_losses_mpl.avg, args.num_eval)
                args.writer.add_scalar("train/6.mask", mean_mask.avg, args.num_eval)

                test_model = avg_student_model if avg_student_model is not None else student_model
                test_loss, top1, top5 = evaluate(args, test_loader, test_model, criterion)

                args.writer.add_scalar("test/loss", test_loss, args.num_eval)
                args.writer.add_scalar("test/acc@1", top1, args.num_eval)
                args.writer.add_scalar("test/acc@5", top5, args.num_eval)

                is_best = top1 > args.best_top1
                if is_best:
                    args.best_top1 = top1
                    args.best_top5 = top5

                logger.info(f"top-1 acc: {top1:.2f}")
                logger.info(f"Best top-1 acc: {args.best_top1:.2f}")

                save_checkpoint(args, {
                    'step': step + 1,
                    'teacher_state_dict': teacher_model.state_dict(),
                    'student_state_dict': student_model.state_dict(),
                    'avg_state_dict': avg_student_model.state_dict() if avg_student_model is not None else None,
                    'best_top1': args.best_top1,
                    'best_top5': args.best_top5,
                    'teacher_optimizer': t_optimizer.state_dict(),
                    'student_optimizer': s_optimizer.state_dict(),
                    'teacher_scheduler': t_scheduler.state_dict(),
                    'student_scheduler': s_scheduler.state_dict(),
                    'teacher_scaler': t_scaler.state_dict(),
                    'student_scaler': s_scaler.state_dict(),
                }, is_best)

    # finetune
    del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader
    del s_scaler, s_scheduler, s_optimizer
    ckpt_name = f'{args.save_path}/{args.name}_best.pth.tar'
    loc = f'cuda:{args.gpu}'
    checkpoint = torch.load(ckpt_name, map_location=loc)
    logger.info(f"=> loading checkpoint '{ckpt_name}'")
    if checkpoint['avg_state_dict'] is not None:
        model_load_state_dict(student_model, checkpoint['avg_state_dict'])
    else:
        model_load_state_dict(student_model, checkpoint['student_state_dict'])
    finetune(args, labeled_loader, test_loader, student_model, criterion)
    return
def fit(num_epoch=args['num_train_epochs']):
    global_step = 0
    model.train()
    for i_ in tqdm(range(int(num_epoch)), desc="Epoch"):
        print('Current epoch ******************************', i_)
        tr_loss, tr_accuracy = 0, 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for index, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            try:
                logits = model(input_ids, segment_ids, input_mask, label_ids)
                tmp_train_loss = loss_fct(logits.view(-1, num_labels), label_ids.squeeze())
                tmp_train_accuracy = accuracy(
                    logits.view(-1, num_labels).detach().cpu().numpy(),
                    label_ids.squeeze().detach().cpu().numpy())

                if n_gpu > 1:
                    tmp_train_loss = tmp_train_loss.mean()  # mean() to average on multi-gpu.
                if args["local_rank"] != -1:
                    tmp_train_loss = reduce_tensor(tmp_train_loss)
                    tmp_train_accuracy = reduce_tensor(torch.tensor(tmp_train_accuracy).to(device))

                tmp_train_loss = tmp_train_loss / args['gradient_accumulation_steps']

                with amp.scale_loss(tmp_train_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                # if args['fp16']:
                #     optimizer.backward(tmp_train_loss)
                # else:
                #     tmp_train_loss.backward()

                if (index + 1) % args['gradient_accumulation_steps'] == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                tr_loss += tmp_train_loss.item()
                tr_accuracy += tmp_train_accuracy.item()

                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                global_step += 1
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise e

            # Tensorboard Logging
            eval_loss, eval_accuracy = 0, 0
            if global_step % 100 == 0:
                eval_loss, eval_accuracy = eval()
                logger.info('tr_loss:{} & tr_accuracy:{}'.format(
                    tr_loss / nb_tr_steps, tr_accuracy / nb_tr_examples))
                logger.info('eval_loss:{} & eval_accuracy:{}'.format(eval_loss, eval_accuracy))
                info = {
                    'tr_loss': tr_loss / nb_tr_steps,
                    'tr_accuracy': tr_accuracy / nb_tr_examples
                }
                for tag, value in info.items():
                    loggers.scalar_summary(tag, value, global_step + 1)
                info = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}
                for tag, value in info.items():
                    loggers.scalar_summary(tag, value, global_step + 1)

            # Save the model
            if global_step % 200 == 0:
                params.append(eval_accuracy)
                if eval_accuracy >= max(params):
                    if args["local_rank"] == -1:
                        # Only save the model it-self
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_model_file = os.path.join(model_path, "finetuned_pytorch_model.bin")
                        torch.save(model_to_save.state_dict(), output_model_file)
                    elif args["local_rank"] == 0:
                        checkpoint = {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'amp': amp.state_dict()
                        }
                        output_model_file = os.path.join(model_path, "amp_checkpoint.pt")
                        torch.save(checkpoint, output_model_file)
                    # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                    # output_model_file = os.path.join(model_path, "checkpoint.pt")
                    # torch.save({
                    #     'model': model_to_save.state_dict()
                    # }, output_model_file)

            if args["fp16"]:
                # scheduler.batch_step()
                # modify learning rate with special warm up BERT uses
                lr_this_step = args['learning_rate'] * warmup_linear(
                    global_step / t_total, args['warmup_proportion'])
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
            else:
                scheduler.step()
def ens_attack(val_loader, model, criterion, args, log, num_mc_samples=20):
    def _grad(X, y, mean, std):
        probs = torch.zeros(num_mc_samples, X.shape[0]).cuda(args.gpu)
        grads = torch.zeros(num_mc_samples, *list(X.shape)).cuda(args.gpu)
        for j in range(num_mc_samples):
            with model.no_sync():
                with torch.enable_grad():
                    X.requires_grad_()
                    output = model(X.sub(mean).div(std))
                    loss = torch.nn.functional.cross_entropy(output, y, reduction='none')
                    grad_ = torch.autograd.grad(
                        [loss], [X], grad_outputs=torch.ones_like(loss),
                        retain_graph=False)[0].detach()
            grads[j] = grad_
            probs[j] = torch.gather(output.detach().softmax(-1), 1, y[:, None]).squeeze()
        probs /= probs.sum(0)
        grad_ = (grads * probs[:, :, None, None, None]).sum(0)
        return grad_

    def _pgd_whitebox(X, y, mean, std):
        X_pgd = X.clone()
        if args.random:
            X_pgd += torch.cuda.FloatTensor(*X_pgd.shape).uniform_(-args.epsilon, args.epsilon)
        for _ in range(args.num_steps):
            grad_ = _grad(X_pgd, y, mean, std)
            X_pgd += args.step_size * grad_.sign()
            eta = torch.clamp(X_pgd - X, -args.epsilon, args.epsilon)
            X_pgd = torch.clamp(X + eta, 0, 1.0)

        mis = 0
        preds = 0
        for ens in range(num_mc_samples):
            output = model(X_pgd.sub(mean).div(std))
            mis = (mis * ens + (-output.softmax(-1) * (output).log_softmax(-1)).sum(1)) / (ens + 1)
            preds = (preds * ens + output.softmax(-1)) / (ens + 1)

        loss = criterion((preds + 1e-8).log(), target)
        prec1, prec5 = accuracy(preds, target, topk=(1, 5))
        mis = (-preds * (preds + 1e-8).log()).sum(1) - mis
        return loss, prec1, prec5, mis

    if args.dataset == 'cifar10':
        mean = torch.from_numpy(np.array(
            [x / 255 for x in [125.3, 123.0, 113.9]])).view(1, 3, 1, 1).cuda(args.gpu).float()
        std = torch.from_numpy(np.array(
            [x / 255 for x in [63.0, 62.1, 66.7]])).view(1, 3, 1, 1).cuda(args.gpu).float()
    elif args.dataset == 'cifar100':
        mean = torch.from_numpy(np.array(
            [x / 255 for x in [129.3, 124.1, 112.4]])).view(1, 3, 1, 1).cuda(args.gpu).float()
        std = torch.from_numpy(np.array(
            [x / 255 for x in [68.2, 65.4, 70.4]])).view(1, 3, 1, 1).cuda(args.gpu).float()
    elif args.dataset == 'imagenet':
        mean = torch.from_numpy(np.array(
            [0.485, 0.456, 0.406])).view(1, 3, 1, 1).cuda(args.gpu).float()
        std = torch.from_numpy(np.array(
            [0.229, 0.224, 0.225])).view(1, 3, 1, 1).cuda(args.gpu).float()

    losses, top1, top5 = 0, 0, 0
    model.eval()
    with torch.no_grad():
        mis = []
        for i, (input, target) in enumerate(val_loader):
            input = input.cuda(args.gpu, non_blocking=True).mul_(std).add_(mean)
            target = target.cuda(args.gpu, non_blocking=True)
            loss, prec1, prec5, mis_ = _pgd_whitebox(input, target, mean, std)
            losses += loss * target.size(0)
            top1 += prec1 * target.size(0)
            top5 += prec5 * target.size(0)
            mis.append(mis_)

        mis = torch.cat(mis, 0)
        losses /= mis.size(0)
        top1 /= mis.size(0)
        top5 /= mis.size(0)
        losses = reduce_tensor(losses.data, args)
        top1 = reduce_tensor(top1.data, args)
        top5 = reduce_tensor(top5.data, args)
        if args.distributed:
            mis = dist_collect(mis)

    print_log('ADV ensemble TOP1: {:.4f}, TOP5: {:.4f}, LOS: {:.4f}'.format(
        top1.item(), top5.item(), losses.item()), log)
    if args.gpu == 0:
        np.save(os.path.join(args.save_path, 'mis_advg.npy'), mis.data.cpu().numpy())
def eval():
    args['output_dir'].mkdir(exist_ok=True)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args['max_seq_length'], tokenizer)
    logger.info("***** Running evaluation *****")
    logger.info(" Num examples = %d", len(eval_examples))
    logger.info(" Batch size = %d", args['eval_batch_size'])

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in eval_features], dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    # Run prediction for full data
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args['eval_batch_size'])

    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)
        print("device:{}".format(device))

        with torch.no_grad():
            # tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
            logits = model(input_ids, segment_ids, input_mask)
            tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.squeeze())
            tmp_eval_accuracy = accuracy(
                logits.view(-1, num_labels).detach().cpu().numpy(),
                label_ids.squeeze().detach().cpu().numpy())
            if args["local_rank"] != -1:
                tmp_eval_loss = reduce_tensor(tmp_eval_loss)
                tmp_eval_accuracy = reduce_tensor(torch.tensor(tmp_eval_accuracy).to(device))

        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy.item()

        nb_eval_examples += input_ids.size(0)
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples
    result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}

    output_eval_file = os.path.join(args['output_dir'], "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
    return eval_loss, eval_accuracy
data_time.update(time.time() - end)
inputs = inputs.to(device)

out, output_sizes = model(inputs, input_sizes)
out = out.transpose(0, 1)  # TxNxH

float_out = out.float()  # ensure float32 for loss
# print(float_out.to('cpu'))
# break
loss = criterion(float_out.to('cpu'), targets, output_sizes, target_sizes).to(device)
loss = loss / inputs.size(0)  # average the loss by minibatch
if args.distributed:
    loss = loss.to(device)
    loss_value = reduce_tensor(loss, args.world_size).item()
else:
    loss_value = loss.item()

# Check to ensure valid loss was calculated
valid_loss, error = check_loss(loss, loss_value)
if valid_loss:
    optimizer.zero_grad()

    # compute gradient
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
    optimizer.step()
# if i % 16 == 15:
def forward(self, x, step, writer=None):
    # x is (n, l, c)
    batch_size = x.size(0)
    num_points = x.size(1)
    z_mu, z_sigma = self.encoder(x)  # assume z_sigma is ln(sigma)
    if self.use_deterministic_encoder:
        z = z_mu + 0 * z_sigma
    else:
        z = self.reparameterize_gaussian(z_mu, z_sigma)

    # Compute H[Q(z|X)]
    if self.use_deterministic_encoder:
        entropy = torch.zeros(batch_size).to(z)
    else:
        entropy = self.gaussian_entropy(z_sigma)

    # Compute the prior probability P(z)
    if self.use_latent_flow:
        w, delta_log_pw = self.latent_rsf(z, torch.zeros(batch_size, 1).to(z))
        log_pw = standard_normal_logprob(w).view(batch_size, -1).sum(1, keepdim=True)
        delta_log_pw = delta_log_pw.view(batch_size, 1)
        log_pz = log_pw - delta_log_pw
    else:
        log_pz = torch.zeros(batch_size, 1).to(z)

    # Compute the reconstruction likelihood P(X|z)
    # z_new = z.view(*z.size())
    # z_new = z_new + (log_pz * 0.).mean()
    y, delta_log_py = self.point_rsf(x, torch.zeros(batch_size, num_points, 1).to(x))
    log_py = standard_normal_logprob(y).view(batch_size, -1).sum(1, keepdim=True)
    delta_log_py = delta_log_py.view(batch_size, num_points, 1).sum(1)
    log_px = log_py - delta_log_py

    # Loss
    entropy_loss = -entropy.mean() * self.entropy_weight
    recon_loss = -log_px.mean() * self.recon_weight
    prior_loss = -log_pz.mean() * self.prior_weight
    loss = entropy_loss + prior_loss + recon_loss

    # LOGGING (after the training)
    if self.distributed:
        entropy_log = reduce_tensor(entropy.mean())
        recon = reduce_tensor(-log_px.mean())
        prior = reduce_tensor(-log_pz.mean())
    else:
        entropy_log = entropy.mean()
        recon = -log_px.mean()
        prior = -log_pz.mean()

    recon_nats = recon / float(x.size(1) * x.size(2))
    prior_nats = prior / float(self.zdim)

    if writer is not None:
        writer.add_scalar('train/entropy', entropy_log, step)
        writer.add_scalar('train/prior', prior, step)
        writer.add_scalar('train/prior(nats)', prior_nats, step)
        writer.add_scalar('train/recon', recon, step)
        writer.add_scalar('train/recon(nats)', recon_nats, step)

    return {
        'entropy': entropy_log.cpu().detach().item() if not isinstance(entropy_log, float) else entropy_log,
        'prior_nats': prior_nats,
        'recon_nats': recon_nats,
    }, loss