def train_val(config): device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') train_loader = get_dataloader(img_dir=config.train_img_dir, mask_dir=config.train_mask_dir, mode="train", batch_size=config.batch_size, num_workers=config.num_workers) val_loader = get_dataloader(img_dir=config.val_img_dir, mask_dir=config.val_mask_dir, mode="val", batch_size=config.batch_size, num_workers=config.num_workers) writer = SummaryWriter( comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" % (config.lr, config.batch_size, config.model_type, config.data_type)) if config.model_type not in [ 'UNet', 'R2UNet', 'AUNet', 'R2AUNet', 'SEUNet', 'SEUNet++', 'UNet++', 'DAUNet', 'DANet', 'AUNetR', 'RendDANet', "RendUNet" ]: print('ERROR!! model_type should be selected in supported models') print('Choose model %s' % config.model_type) return if config.model_type == "UNet": model = UNet() elif config.model_type == "AUNet": model = AUNet() elif config.model_type == "R2UNet": model = R2UNet() elif config.model_type == "SEUNet": model = SEUNet(useCSE=False, useSSE=False, useCSSE=True) elif config.model_type == "UNet++": model = UNetPP() elif config.model_type == "DANet": model = DANet(backbone='resnet101', nclass=config.output_ch) elif config.model_type == "AUNetR": model = AUNet_R16(n_classes=1, learned_bilinear=True) elif config.model_type == "RendDANet": model = RendDANet(backbone='resnet101', nclass=config.output_ch) elif config.model_type == "RendUNet": model = RendUNet() else: model = UNet() if torch.cuda.device_count() > 1: print("Let's use", torch.cuda.device_count(), "GPUs!") model = nn.DataParallel(model) model = model.to(device) print('# parameters:', sum(param.numel() for param in model.parameters())) if config.optimizer == "sgd": optimizer = SGD(model.seg.parameters(), lr=1e-2, weight_decay=1e-6, momentum=0.9) else: optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) if config.loss == "dice": criterion = DiceLoss() elif config.loss == "bce": criterion = nn.BCELoss() elif config.loss == "mix": criterion = MixLoss() else: criterion = MultiRendLoss_v10() scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) global_step = 0 best_dice = 0.0 for epoch in range(config.num_epochs): epoch_loss = 0.0 with tqdm(total=config.num_train, desc="Epoch %d / %d" % (epoch + 1, config.num_epochs), unit='img') as train_pbar: model.train() for image, mask in train_loader: image = image.to(device, dtype=torch.float32) mask = mask.to(device, dtype=torch.float32) output = model(image) loss = criterion(output, mask) epoch_loss += loss.item() writer.add_scalar('Loss/train', loss.item(), global_step) train_pbar.set_postfix(**{'loss (batch)': loss.item()}) optimizer.zero_grad() loss.backward() optimizer.step() train_pbar.update(image.shape[0]) global_step += 1 scheduler.step() epoch_dice = 0.0 epoch_acc = 0.0 epoch_sen = 0.0 epoch_spe = 0.0 epoch_pre = 0.0 current_num = 0 with tqdm(total=config.num_val, desc="Epoch %d / %d validation round" % (epoch + 1, config.num_epochs), unit='img') as val_pbar: model.eval() locker = 0 for image, mask in val_loader: current_num += image.shape[0] image = image.to(device, dtype=torch.float32) mask = mask.to(device, dtype=torch.float32) output = model(image) pred = torch.sigmoid(output['fine']) batch_dice = dice_coeff(mask, pred).item() epoch_dice += batch_dice * image.shape[0] epoch_acc += get_accuracy(pred=pred, true=mask) * image.shape[0] epoch_sen += get_sensitivity(pred=pred, true=mask) * image.shape[0] epoch_spe += get_specificity(pred=pred, true=mask) * 
image.shape[0] epoch_pre += get_precision(pred=pred, true=mask) * image.shape[0] if locker == 200: writer.add_images('masks/true', mask, epoch + 1) writer.add_images('masks/pred', pred > 0.5, epoch + 1) val_pbar.set_postfix(**{'dice (batch)': batch_dice}) val_pbar.update(image.shape[0]) locker += 1 epoch_dice /= float(current_num) epoch_acc /= float(current_num) epoch_sen /= float(current_num) epoch_spe /= float(current_num) epoch_pre /= float(current_num) epoch_f1 = get_F1(SE=epoch_sen, PR=epoch_pre) if epoch_dice > best_dice: best_dice = epoch_dice writer.add_scalar('Best Dice/test', best_dice, epoch + 1) torch.save( model, config.result_path + "/%s_%s_%d.pth" % (config.model_type, str(epoch_dice), epoch + 1)) logging.info('Validation Dice Coeff: {}'.format(epoch_dice)) print("epoch dice: " + str(epoch_dice)) writer.add_scalar('Dice/test', epoch_dice, epoch + 1) writer.add_scalar('Acc/test', epoch_acc, epoch + 1) writer.add_scalar('Sen/test', epoch_sen, epoch + 1) writer.add_scalar('Spe/test', epoch_spe, epoch + 1) writer.add_scalar('Pre/test', epoch_pre, epoch + 1) writer.add_scalar('F1/test', epoch_f1, epoch + 1) writer.close() print("Training finished")
def train(self, epoch, hparam=None): ''' Inputs: - hparam: dictionary of hyperparameters. Save average epoch loss, train and validation accuracy to tensorboard. After half of the training epoch, save model, optimizer ,and scalar state dict, current epoch, stats and config during checkpoint ''' model_start_time = self.config['model_start_time'] previous_epoch = self.config['previous_epoch'] check_every_epoch = self.config['check_every_epoch'] self.config['hparam'] = hparam epoch += previous_epoch # for load previous model if hparam: writer = SummaryWriter('runs/' + model_start_time) checkpoint_cycle_flag = True num_batch = len(self.train_loader) self.model.train() for i in range(previous_epoch + 1, epoch + 1): total_loss = 0 iter_loss_history = [] Y_pred_all = [] Y_tr_all = [] if checkpoint_cycle_flag: checkpoint_cycle_flag = False checkpoint_start_time = time.time() for j, data in zip(s := trange(num_batch, leave=False), self.train_loader): Xtr, Ytr = data Xtr, Ytr = Xtr.to(**self.to_float_cuda, non_blocking=True), Ytr.cuda( non_blocking=True) ################################## Future changes ########################################################## loss, y_pred = self.train_fn(Xtr, Ytr) ############################################################################################################ total_loss += loss # Iter Book keeping Y_pred_all.append(y_pred) Y_tr_all.append(Ytr) iter_loss_history.append(loss) # update progress bar s.set_description(f'Epoch {i}/{epoch} Loss: {loss:.4f} ') avg_loss = total_loss / num_batch # Epoch Book keeping self.stats['iter_loss'].append(iter_loss_history) self.stats['avg_loss'].append(avg_loss) # Enter checkpoint block after first and last epoch and specify checkpoint interval if i % check_every_epoch == 0 or i == epoch: checkpoint_cycle_flag = True cur_lr = self.optimizer.param_groups[0]['lr'] # check train accuracy by using saved results during forward pass to save computation. Y_pred_all = torch.argmax(torch.cat(Y_pred_all), dim=1) Y_tr_all = torch.cat(Y_tr_all) train_accuracy = (Y_pred_all == Y_tr_all).float().mean() # check val accuracy val_accuracy, val_loss = self._check_accuracy(self.val_loader) # check update ratio ratio = self._check_update_ratio(cur_lr) print( f'Epoch: {i}/{epoch}, train loss: {avg_loss:.4f}, val loss: {val_loss:.4f}, train acc: {train_accuracy:.4f}, val acc: {val_accuracy:.4f},lr: {cur_lr:.4e}, update ratio: {ratio:.2e}, took {(time.time()-checkpoint_start_time):.2f} seconds' ) # Checkpoint Book keeping self.stats['train_acc'].append(train_accuracy) self.stats['val_acc'].append(val_accuracy) self.stats['ratio'].append(ratio) if hparam: writer.add_scalar('Epoch loss', avg_loss, i) writer.add_scalars('accuracy', { 'train': train_accuracy, 'val': val_accuracy }, i) # only save model checkpoint after half of the training process if i > epoch // 2: self._save_checkpoint(epoch=i) # decay learning rate after complete one epoch if self.lr_scheduler is not None: self.lr_scheduler.step()
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]: """ Train the model """ if args.local_rank in [-1, 0]: tb_writer = SummaryWriter() args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) def collate(examples: List[torch.Tensor]): if tokenizer._pad_token is None: return pad_sequence(examples, batch_first=True) return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) train_sampler = RandomSampler( train_dataset) if args.local_rank == -1 else DistributedSampler( train_dataset) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate) if args.max_steps > 0: t_total = args.max_steps args.num_train_epochs = args.max_steps // ( len(train_dataloader) // args.gradient_accumulation_steps) + 1 else: t_total = len( train_dataloader ) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": args.weight_decay, }, { "params": [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0 }, ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) # Check if saved optimizer or scheduler states exist if (args.model_name_or_path and os.path.isfile( os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( os.path.join(args.model_name_or_path, "scheduler.pt"))): # Load in optimizer and scheduler states optimizer.load_state_dict( torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) scheduler.load_state_dict( torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) if args.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) # multi-gpu training (should be after apex fp16 initialization) if args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if args.model_name_or_path and os.path.exists(args.model_name_or_path): try: # set global_step to gobal_step of last saved checkpoint from model path checkpoint_suffix = args.model_name_or_path.split("-")[-1].split( "/")[0] global_step = int(checkpoint_suffix) epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) steps_trained_in_current_epoch = global_step % ( len(train_dataloader) // args.gradient_accumulation_steps) logger.info( " Continuing training from checkpoint, will skip to saved global_step" ) logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", global_step) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) except ValueError: logger.info(" Starting fine-tuning.") tr_loss, logging_loss = 0.0, 0.0 model_to_resize = model.module if hasattr( model, "module") else model # Take care of distributed/parallel training model_to_resize.resize_token_embeddings(len(tokenizer)) model.zero_grad() train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed(args) # Added here for reproducibility for _ in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) for step, batch in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) inputs = inputs.to(args.device) labels = labels.to(args.device) model.train() loss = model(inputs, masked_lm_labels=labels) if args.mlm else model( inputs, labels=labels) if args.n_gpu > 1: loss = loss.mean( ) # mean() to average on multi-gpu parallel training if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if args.local_rank in [ -1, 0 ] and args.logging_steps > 0 and global_step % args.logging_steps == 0: # Log metrics if ( args.local_rank == -1 and args.evaluate_during_training ): # Only evaluate when single GPU otherwise metrics may not average well results = evaluate(args, model, tokenizer) for key, value in results.items(): tb_writer.add_scalar("eval_{}".format(key), value, global_step) tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) logging_loss = tr_loss if args.local_rank in [ -1, 0 ] and args.save_steps > 0 and global_step % args.save_steps == 0: checkpoint_prefix = 
"checkpoint" # Save model checkpoint output_dir = os.path.join( args.output_dir, "{}-{}".format(checkpoint_prefix, global_step)) os.makedirs(output_dir, exist_ok=True) model_to_save = ( model.module if hasattr(model, "module") else model ) # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) torch.save(args, os.path.join(output_dir, "training_args.bin")) logger.info("Saving model checkpoint to %s", output_dir) _rotate_checkpoints(args, checkpoint_prefix) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) logger.info("Saving optimizer and scheduler states to %s", output_dir) if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break if args.max_steps > 0 and global_step > args.max_steps: train_iterator.close() break if args.local_rank in [-1, 0]: tb_writer.close() return global_step, tr_loss / global_step
def train(gpu, args): """Create the model and start the training.""" rank = args.nr * args.num_gpus + gpu if gpu == 1: gpu = 3 dist.init_process_group(backend="nccl", world_size=args.world_size, rank=rank) if args.batch_size == 1 and args.use_bn is True: raise Exception torch.autograd.set_detect_anomaly(True) torch.manual_seed(args.torch_seed) torch.cuda.manual_seed(args.cuda_seed) torch.cuda.set_device(gpu) w, h = map(int, args.input_size.split(',')) input_size = (w, h) w, h = map(int, args.input_size_target.split(',')) input_size_target = (w, h) cudnn.enabled = True gpu = gpu criterion = DiceBCELoss() # criterion = nn.CrossEntropyLoss(ignore_index=253) # Create network if args.model == 'DeepLab': model = DeeplabMulti(num_classes=args.num_classes) if args.restore_from is None: pass elif args.restore_from[:4] == 'http': saved_state_dict = model_zoo.load_url(args.restore_from) elif args.restore_from is not None: saved_state_dict = torch.load(args.restore_from) model.load_state_dict(saved_state_dict) print("Loaded state dicts for model") # if args.restore_from is not None: # new_params = model.state_dict().copy() # for i in saved_state_dict: # # Scale.layer5.conv2d_list.3.weight # i_parts = i.split('.') # # print i_parts # if not args.num_classes == 19 or not i_parts[1] == 'layer5': # new_params['.'.join(i_parts[1:])] = saved_state_dict[i] # # print i_parts # model.load_state_dict(new_params) if not args.no_logging: if not os.path.isdir(args.log_dir): os.mkdir(args.log_dir) log_dir = os.path.join(args.log_dir, args.exp_dir) if not os.path.isdir(log_dir): os.mkdir(log_dir) if args.exp_name == "": exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d") else: exp_name = args.exp_name log_dir = os.path.join(log_dir, exp_name) writer = SummaryWriter(log_dir) model.train() # model.cuda(gpu) model = model.cuda(device=gpu) if args.num_gpus > 0 or torch.cuda.device_count() > 0: model = DistributedDataParallel(model, device_ids=[gpu], find_unused_parameters=True) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # cudnn.benchmark = True # init D model_D1 = FCDiscriminator(num_classes=args.num_classes) model_D2 = FCDiscriminator(num_classes=args.num_classes) start_epoch = 0 if "http" not in args.restore_from and args.restore_from is not None: root, extension = args.restore_from.strip().split(".") D1pth = root + "_D1." + extension D2pth = root + "_D2." 
+ extension saved_state_dict = torch.load(D1pth) model_D1.load_state_dict(saved_state_dict) saved_state_dict = torch.load(D2pth) model_D2.load_state_dict(saved_state_dict) start_epoch = int(re.findall(r'[\d]+', root)[-1]) print("Loaded state dict for models D1 and D2") model_D1.train() # model_D1.cuda(gpu) model_D2.train() # model_D2.cuda(gpu) model_D1 = model_D1.cuda(device=gpu) model_D2 = model_D2.cuda(device=gpu) if args.num_gpus > 0 or torch.cuda.device_count() > 0: model_D1 = DistributedDataParallel(model_D1, device_ids=[gpu], find_unused_parameters=True) model_D2 = DistributedDataParallel(model_D2, device_ids=[gpu], find_unused_parameters=True) if not os.path.exists(args.snapshot_dir): os.makedirs(args.snapshot_dir) train_dataset = SyntheticSmokeTrain(args={}, dataset_limit=args.num_steps * args.iter_size * args.batch_size, image_shape=input_size, dataset_mean=IMG_MEAN) train_sampler = DistributedSampler(train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True) trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True, sampler=train_sampler) # trainloader = data.DataLoader( # GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size, # crop_size=input_size, # scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN), # batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) trainloader_iter = enumerate(trainloader) print("Length of train dataloader: ", len(trainloader)) target_dataset = SimpleSmokeVal(args={}, image_size=input_size_target, dataset_mean=IMG_MEAN) target_sampler = DistributedSampler(target_dataset, num_replicas=args.world_size, rank=rank, shuffle=True) targetloader = data.DataLoader(target_dataset, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True, sampler=target_sampler) # targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target, # max_iters=args.num_steps * args.iter_size * args.batch_size, # crop_size=input_size_target, # scale=False, mirror=args.random_mirror, mean=IMG_MEAN, # set=args.set), # batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, # pin_memory=True) targetloader_iter = enumerate(targetloader) print("Length of train dataloader: ", len(targetloader)) # implement model.optim_parameters(args) to handle different models' lr setting optimizer = optim.SGD(model.module.optim_parameters(args), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) optimizer.zero_grad() optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) optimizer_D1.zero_grad() optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) optimizer_D2.zero_grad() if args.gan == 'Vanilla': bce_loss = torch.nn.BCEWithLogitsLoss() elif args.gan == 'LS': bce_loss = torch.nn.MSELoss() interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear') interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear') # labels for adversarial training source_label = 0 target_label = 1 for i_iter in range(start_epoch, args.num_steps): loss_seg_value1 = 0 loss_adv_target_value1 = 0 loss_D_value1 = 0 loss_seg_value2 = 0 loss_adv_target_value2 = 0 loss_D_value2 = 0 optimizer.zero_grad() adjust_learning_rate(optimizer, i_iter) optimizer_D1.zero_grad() optimizer_D2.zero_grad() adjust_learning_rate_D(optimizer_D1, i_iter) 
adjust_learning_rate_D(optimizer_D2, i_iter) for sub_i in range(args.iter_size): # train G # don't accumulate grads in D for param in model_D1.parameters(): param.requires_grad = False for param in model_D2.parameters(): param.requires_grad = False # train with source # try: _, batch = next(trainloader_iter) #.next() # except StopIteration: # trainloader = data.DataLoader( # SyntheticSmokeTrain(args={}, dataset_limit=args.num_steps * args.iter_size * args.batch_size, # image_shape=input_size, dataset_mean=IMG_MEAN), # batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) # trainloader_iter = iter(trainloader) # _, batch = next(trainloader_iter) images, labels, _, _ = batch images = Variable(images).cuda(gpu) # print("Shape of labels", labels.shape) # print("Are labels all zero? ") # for i in range(labels.shape[0]): # print("{}: All zero? {}".format(i, torch.all(labels[i]==0))) # print("{}: All 255? {}".format(i, torch.all(labels[i]==255))) # print("{}: Mean = {}".format(i, torch.mean(labels[i]))) pred1, pred2 = model(images) # print("Pred1 and Pred2 original size: {}, {}".format(pred1.shape, pred2.shape)) pred1 = interp(pred1) pred2 = interp(pred2) # print("Pred1 and Pred2 upsampled size: {}, {}".format(pred1.shape, pred2.shape)) # for pred, name in zip([pred1, pred2], ['pred1', 'pred2']): # print(name) # for i in range(pred.shape[0]): # print("{}: All zero? {}".format(i, torch.all(pred[i]==0))) # print("{}: All 255? {}".format(i, torch.all(pred[i]==255))) # print("{}: Mean = {}".format(i, torch.mean(pred[i]))) loss_seg1 = loss_calc(pred1, labels, gpu, criterion) loss_seg2 = loss_calc(pred2, labels, gpu, criterion) loss = loss_seg2 + args.lambda_seg * loss_seg1 # proper normalization loss = loss / args.iter_size loss.backward() # print("Seg1 loss: ",loss_seg1, args.iter_size) # print("Seg2 loss: ",loss_seg2, args.iter_size) loss_seg_value1 += loss_seg1.data.cpu().item() / args.iter_size loss_seg_value2 += loss_seg2.data.cpu().item() / args.iter_size # train with target # try: _, batch = next(targetloader_iter) #.next() # except StopIteration: # targetloader = data.DataLoader( # SimpleSmokeVal(args = {}, image_size=input_size_target, dataset_mean=IMG_MEAN), # batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, # pin_memory=True) # targetloader_iter = iter(targetloader) # _, batch = next(targetloader_iter) images, _, _ = batch images = Variable(images).cuda(gpu) pred_target1, pred_target2 = model(images) pred_target1 = interp_target(pred_target1) pred_target2 = interp_target(pred_target2) D_out1 = model_D1(F.softmax(pred_target1, dim=1)) D_out2 = model_D2(F.softmax(pred_target2, dim=1)) loss_adv_target1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda(gpu)) loss_adv_target2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(source_label)).cuda(gpu)) loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2 loss = loss / args.iter_size loss.backward() loss_adv_target_value1 += loss_adv_target1.data.cpu().item( ) / args.iter_size loss_adv_target_value2 += loss_adv_target2.data.cpu().item( ) / args.iter_size # train D # bring back requires_grad for param in model_D1.parameters(): param.requires_grad = True for param in model_D2.parameters(): param.requires_grad = True # train with source pred1 = pred1.detach() pred2 = pred2.detach() D_out1 = model_D1(F.softmax(pred1, dim=1)) D_out2 = model_D2(F.softmax(pred2, dim=1)) loss_D1 = 
bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(source_label)).cuda(gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(source_label)).cuda(gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 loss_D1.backward() loss_D2.backward() loss_D_value1 += loss_D1.data.cpu().item() loss_D_value2 += loss_D2.data.cpu().item() # train with target pred_target1 = pred_target1.detach() pred_target2 = pred_target2.detach() D_out1 = model_D1(F.softmax(pred_target1, dim=1)) D_out2 = model_D2(F.softmax(pred_target2, dim=1)) loss_D1 = bce_loss( D_out1, Variable( torch.FloatTensor( D_out1.data.size()).fill_(target_label)).cuda(gpu)) loss_D2 = bce_loss( D_out2, Variable( torch.FloatTensor( D_out2.data.size()).fill_(target_label)).cuda(gpu)) loss_D1 = loss_D1 / args.iter_size / 2 loss_D2 = loss_D2 / args.iter_size / 2 loss_D1.backward() loss_D2.backward() loss_D_value1 += loss_D1.data.cpu().item() loss_D_value2 += loss_D2.data.cpu().item() optimizer.step() optimizer_D1.step() optimizer_D2.step() print('exp = {}'.format(args.snapshot_dir)) print( 'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}' .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2)) writer.add_scalar(f'loss/train/segmentation/1', loss_seg_value1, i_iter) writer.add_scalar(f'loss/train/segmentation/2', loss_seg_value2, i_iter) writer.add_scalar(f'loss/train/adversarial/1', loss_adv_target_value1, i_iter) writer.add_scalar(f'loss/train/adversarial/2', loss_adv_target_value2, i_iter) writer.add_scalar(f'loss/train/domain/1', loss_D_value1, i_iter) writer.add_scalar(f'loss/train/domain/2', loss_D_value2, i_iter) if i_iter >= args.num_steps_stop - 1: print('save model ...') torch.save( model.state_dict(), osp.join( args.snapshot_dir, 'smoke_cross_entropy_multigpu_' + str(args.num_steps_stop) + '.pth')) torch.save( model_D1.state_dict(), osp.join( args.snapshot_dir, 'smoke_cross_entropy_multigpu_' + str(args.num_steps_stop) + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join( args.snapshot_dir, 'smoke_cross_entropy_multigpu_' + str(args.num_steps_stop) + '_D2.pth')) break if i_iter % args.save_pred_every == 0 and i_iter != 0: print('taking snapshot ...') torch.save( model.state_dict(), osp.join( args.snapshot_dir, 'smoke_cross_entropy_multigpu_' + str(i_iter) + '.pth')) torch.save( model_D1.state_dict(), osp.join( args.snapshot_dir, 'smoke_cross_entropy_multigpu_' + str(i_iter) + '_D1.pth')) torch.save( model_D2.state_dict(), osp.join( args.snapshot_dir, 'smoke_cross_entropy_multigpu_' + str(i_iter) + '_D2.pth')) writer.flush()
class Trainer: """ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for Transformers. """ model: PreTrainedModel args: TrainingArguments data_collator: DataCollator train_dataset: Optional[Dataset] eval_dataset: Optional[Dataset] compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None prediction_loss_only: bool tb_writer: Optional["SummaryWriter"] = None optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None global_step: Optional[int] = None epoch: Optional[float] = None def __init__( self, model: PreTrainedModel, args: TrainingArguments, neptune, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Dataset] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, prediction_loss_only=False, tb_writer: Optional["SummaryWriter"] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None, ): """ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for Transformers. Args: prediction_loss_only: (Optional) in evaluation and prediction, only return the loss """ self.model = model.to(args.device) self.args = args self.neptune = neptune if data_collator is not None: self.data_collator = data_collator else: self.data_collator = DefaultDataCollator() self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.compute_metrics = compute_metrics self.prediction_loss_only = prediction_loss_only self.optimizers = optimizers if tb_writer is not None: self.tb_writer = tb_writer elif is_tensorboard_available() and self.is_world_master(): self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir) if not is_tensorboard_available(): logger.warning( "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it." ) if is_wandb_available(): self._setup_wandb() else: logger.info( "You are instantiating a Trainer but W&B is not installed. To use wandb logging, " "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface." ) set_seed(self.args.seed) # Create output directory if needed if self.is_world_master(): os.makedirs(self.args.output_dir, exist_ok=True) if is_torch_tpu_available(): # Set an xla_device flag on the model's config. # We'll find a more elegant and not need to do this in the future. 
self.model.config.xla_device = True def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") if is_torch_tpu_available(): train_sampler = get_tpu_sampler(self.train_dataset) else: train_sampler = (SequentialSampler(self.train_dataset) if self.args.local_rank == -1 else SequentialDistributedSampler(self.train_dataset)) data_loader = DataLoader( self.train_dataset, batch_size=self.args.train_batch_size, sampler=train_sampler, collate_fn=self.data_collator.collate_batch, drop_last=self.args.dataloader_drop_last, ) return data_loader def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None ) -> DataLoader: if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset if is_torch_tpu_available(): sampler = SequentialDistributedSampler( eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) elif self.args.local_rank != -1: sampler = SequentialDistributedSampler(eval_dataset) else: sampler = SequentialSampler(eval_dataset) data_loader = DataLoader( eval_dataset, sampler=sampler, batch_size=self.args.eval_batch_size, collate_fn=self.data_collator.collate_batch, drop_last=self.args.dataloader_drop_last, ) return data_loader def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: # We use the same batch_size as for eval. if is_torch_tpu_available(): sampler = SequentialDistributedSampler( test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) elif self.args.local_rank != -1: sampler = SequentialDistributedSampler(test_dataset) else: sampler = SequentialSampler(test_dataset) data_loader = DataLoader( test_dataset, sampler=sampler, batch_size=self.args.eval_batch_size, collate_fn=self.data_collator.collate_batch, drop_last=self.args.dataloader_drop_last, ) return data_loader def get_optimizers( self, num_training_steps: int ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]: """ Setup the optimizer and the learning rate scheduler. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init, or override this method in a subclass. """ if self.optimizers is not None: return self.optimizers # Prepare optimizer and schedule (linear warmup and decay) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": self.args.weight_decay, }, { "params": [ p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0, }, ] optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps) return optimizer, scheduler def _setup_wandb(self): """ Setup the optional Weights & Biases (`wandb`) integration. One can override this method to customize the setup if needed. 
Find more information at https://docs.wandb.com/huggingface You can also override the following environment variables: Environment: WANDB_WATCH: (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging or "all" to log gradients and parameters WANDB_PROJECT: (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project WANDB_DISABLED: (Optional): boolean - defaults to false, set to "true" to disable wandb entirely """ if self.is_world_master(): logger.info( 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' ) wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args)) # keep track of model topology and gradients if os.getenv("WANDB_WATCH") != "false": wandb.watch(self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)) def num_examples(self, dataloader: DataLoader) -> int: """ Helper to get num of examples from a DataLoader, by accessing its Dataset. """ return len(dataloader.dataset) def train(self, model_path: Optional[str] = None): """ Main training entry point. Args: model_path: (Optional) Local path to model if model to train has been instantiated from a local path If present, we will try reloading the optimizer/scheduler states from there. """ train_dataloader = self.get_train_dataloader() if self.args.max_steps > 0: t_total = self.args.max_steps num_train_epochs = (self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1) else: t_total = int( len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) num_train_epochs = self.args.num_train_epochs optimizer, scheduler = self.get_optimizers(num_training_steps=t_total) # Check if saved optimizer or scheduler states exist if (model_path is not None and os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(os.path.join(model_path, "scheduler.pt"))): # Load in optimizer and scheduler states optimizer.load_state_dict( torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)) scheduler.load_state_dict( torch.load(os.path.join(model_path, "scheduler.pt"))) model = self.model if self.args.fp16: if not is_apex_available(): raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize( model, optimizer, opt_level=self.args.fp16_opt_level) # multi-gpu training (should be after apex fp16 initialization) if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if self.args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.args.local_rank], output_device=self.args.local_rank, find_unused_parameters=True, ) if self.tb_writer is not None: self.tb_writer.add_text("args", self.args.to_json_string()) self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={}) # Train! 
if is_torch_tpu_available(): total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size( ) else: total_train_batch_size = (self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)) logger.info("***** Running training *****") logger.info(" Num examples = %d", self.num_examples(train_dataloader)) logger.info(" Num Epochs = %d", num_train_epochs) logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) logger.info( " Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) self.global_step = 0 self.epoch = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if model_path is not None: # set global_step to global_step of last saved checkpoint from model path try: self.global_step = int(model_path.split("-")[-1].split("/")[0]) epochs_trained = self.global_step // ( len(train_dataloader) // self.args.gradient_accumulation_steps) steps_trained_in_current_epoch = self.global_step % ( len(train_dataloader) // self.args.gradient_accumulation_steps) logger.info( " Continuing training from checkpoint, will skip to saved global_step" ) logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", self.global_step) logger.info( " Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) except ValueError: self.global_step = 0 logger.info(" Starting fine-tuning.") accumulation_loss = 0.0 tr_loss = 0.0 logging_loss = 0.0 model.zero_grad() train_iterator = trange(epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()) for epoch in train_iterator: if isinstance(train_dataloader, DataLoader) and isinstance( train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) if is_torch_tpu_available(): parallel_loader = pl.ParallelLoader( train_dataloader, [self.args.device]).per_device_loader(self.args.device) epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master()) else: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master()) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue loss = self._training_step(model, inputs, optimizer) accumulation_loss += loss tr_loss += loss if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator)): if self.args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), self.args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) if is_torch_tpu_available(): xm.optimizer_step(optimizer) else: optimizer.step() scheduler.step() model.zero_grad() self.global_step += 1 self.epoch = epoch + (step + 1) / len(train_dataloader) if is_torch_tpu_available(): if xm.get_ordinal() == 0: self.neptune.log_metric('loss', self.global_step, accumulation_loss) else: self.neptune.log_metric('loss', self.global_step, accumulation_loss) accumulation_loss 
= 0.0 if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (self.global_step == 1 and self.args.logging_first_step): logs: Dict[str, float] = {} logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps # backward compatibility for pytorch schedulers logs["learning_rate"] = ( scheduler.get_last_lr()[0] if version.parse(torch.__version__) >= version.parse("1.4") else scheduler.get_lr()[0]) logging_loss = tr_loss self._log(logs) if self.args.evaluate_during_training: self.evaluate() if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0: # In all cases (even distributed/parallel), self.model is always a reference # to the model we want to save. if hasattr(model, "module"): assert model.module is self.model else: assert model is self.model # Save model checkpoint output_dir = os.path.join( self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}") self.save_model(output_dir) if self.is_world_master(): self._rotate_checkpoints() if is_torch_tpu_available(): xm.rendezvous("saving_optimizer_states") xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) elif self.is_world_master(): torch.save( optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save( scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) if self.args.max_steps > 0 and self.global_step > self.args.max_steps: epoch_iterator.close() break if self.args.max_steps > 0 and self.global_step > self.args.max_steps: train_iterator.close() break if self.args.tpu_metrics_debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) if self.tb_writer: self.tb_writer.close() logger.info( "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n" ) return TrainOutput(self.global_step, tr_loss / self.global_step) def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None: if self.epoch is not None: logs["epoch"] = self.epoch if self.tb_writer: for k, v in logs.items(): if isinstance(v, (int, float)): self.tb_writer.add_scalar(k, v, self.global_step) else: logger.warning( "Trainer is attempting to log a value of " '"%s" of type %s for key "%s" as a scalar. 
' "This invocation of Tensorboard's writer.add_scalar() " "is incorrect so we dropped this attribute.", v, type(v), k, ) self.tb_writer.flush() if is_wandb_available(): if self.is_world_master(): wandb.log(logs, step=self.global_step) output = json.dumps({**logs, **{"step": self.global_step}}) if iterator is not None: iterator.write(output) else: print(output) def _training_step(self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer) -> float: model.train() for k, v in inputs.items(): inputs[k] = v.to(self.args.device) outputs = model(**inputs) loss = outputs[ 0] # model outputs are always tuple in transformers (see doc) if self.args.n_gpu > 1: loss = loss.mean( ) # mean() to average on multi-gpu parallel training if self.args.gradient_accumulation_steps > 1: loss = loss / self.args.gradient_accumulation_steps if self.args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() return loss.item() def is_local_master(self) -> bool: if is_torch_tpu_available(): return xm.is_master_ordinal(local=True) else: return self.args.local_rank in [-1, 0] def is_world_master(self) -> bool: """ This will be True only in one process, even in distributed mode, even when training on multiple machines. """ if is_torch_tpu_available(): return xm.is_master_ordinal(local=False) else: return self.args.local_rank == -1 or torch.distributed.get_rank( ) == 0 def save_model(self, output_dir: Optional[str] = None): """ Saving best-practices: if you use default names for the model, you can reload it using from_pretrained(). Will only save from the world_master process (unless in TPUs). """ if is_torch_tpu_available(): self._save_tpu(output_dir) elif self.is_world_master(): self._save(output_dir) def _save_tpu(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir logger.info("Saving model checkpoint to %s", output_dir) if xm.is_master_ordinal(): os.makedirs(output_dir, exist_ok=True) torch.save(self.args, os.path.join(output_dir, "training_args.bin")) # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): raise ValueError( "Trainer.model appears to not be a PreTrainedModel") xm.rendezvous("saving_checkpoint") self.model.save_pretrained(output_dir) def _save(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info("Saving model checkpoint to %s", output_dir) # Save a trained model and configuration using `save_pretrained()`. 
# They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): raise ValueError( "Trainer.model appears to not be a PreTrainedModel") self.model.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, "training_args.bin")) def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: ordering_and_checkpoint_path = [] glob_checkpoints = [ str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*") ] for path in glob_checkpoints: if use_mtime: ordering_and_checkpoint_path.append( (os.path.getmtime(path), path)) else: regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) if regex_match and regex_match.groups(): ordering_and_checkpoint_path.append( (int(regex_match.groups()[0]), path)) checkpoints_sorted = sorted(ordering_and_checkpoint_path) checkpoints_sorted = [ checkpoint[1] for checkpoint in checkpoints_sorted ] return checkpoints_sorted def _rotate_checkpoints(self, use_mtime=False) -> None: if self.args.save_total_limit is None or self.args.save_total_limit <= 0: return # Check if we should delete older checkpoint(s) checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime) if len(checkpoints_sorted) <= self.args.save_total_limit: return number_of_checkpoints_to_delete = max( 0, len(checkpoints_sorted) - self.args.save_total_limit) checkpoints_to_be_deleted = checkpoints_sorted[: number_of_checkpoints_to_delete] for checkpoint in checkpoints_to_be_deleted: logger.info( "Deleting older checkpoint [{}] due to args.save_total_limit". format(checkpoint)) shutil.rmtree(checkpoint) def evaluate( self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None, ) -> Dict[str, float]: """ Run evaluation and return metrics. The calling script will be responsible for providing a method to compute metrics, as they are task-dependent. Args: eval_dataset: (Optional) Pass a dataset if you wish to override the one on the instance. Returns: A dict containing: - the eval loss - the potential metrics computed from the predictions """ eval_dataloader = self.get_eval_dataloader(eval_dataset) output = self._prediction_loop(eval_dataloader, description="Evaluation") self._log(output.metrics) if self.args.tpu_metrics_debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) return output.metrics def predict(self, test_dataset: Dataset) -> PredictionOutput: """ Run prediction and return predictions and potential metrics. Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method will also return metrics, like in evaluate(). """ test_dataloader = self.get_test_dataloader(test_dataset) return self._prediction_loop(test_dataloader, description="Prediction") def _prediction_loop( self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None) -> PredictionOutput: """ Prediction/evaluation loop, shared by `evaluate()` and `predict()`. Works both with or without labels. 
""" prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only model = self.model # multi-gpu eval if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) else: model = self.model # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. batch_size = dataloader.batch_size logger.info("***** Running %s *****", description) logger.info(" Num examples = %d", self.num_examples(dataloader)) logger.info(" Batch size = %d", batch_size) eval_losses: List[float] = [] preds: torch.Tensor = None label_ids: torch.Tensor = None model.eval() if is_torch_tpu_available(): dataloader = pl.ParallelLoader( dataloader, [self.args.device]).per_device_loader(self.args.device) for inputs in tqdm(dataloader, desc=description): has_labels = any( inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"]) for k, v in inputs.items(): inputs[k] = v.to(self.args.device) with torch.no_grad(): outputs = model(**inputs) if has_labels: step_eval_loss, logits = outputs[:2] eval_losses += [step_eval_loss.mean().item()] else: logits = outputs[0] if not prediction_loss_only: if preds is None: preds = logits.detach() else: preds = torch.cat((preds, logits.detach()), dim=0) if inputs.get("labels") is not None: if label_ids is None: label_ids = inputs["labels"].detach() else: label_ids = torch.cat( (label_ids, inputs["labels"].detach()), dim=0) if self.args.local_rank != -1: # In distributed mode, concatenate all results from all nodes: if preds is not None: preds = self.distributed_concat( preds, num_total_examples=self.num_examples(dataloader)) if label_ids is not None: label_ids = self.distributed_concat( label_ids, num_total_examples=self.num_examples(dataloader)) elif is_torch_tpu_available(): # tpu-comment: Get all predictions and labels from all worker shards of eval dataset if preds is not None: preds = xm.mesh_reduce("eval_preds", preds, torch.cat) if label_ids is not None: label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) # Finally, turn the aggregated tensors into numpy arrays. if preds is not None: preds = preds.cpu().numpy() if label_ids is not None: label_ids = label_ids.cpu().numpy() if self.compute_metrics is not None and preds is not None and label_ids is not None: metrics = self.compute_metrics( EvalPrediction(predictions=preds, label_ids=label_ids)) else: metrics = {} if len(eval_losses) > 0: metrics["eval_loss"] = np.mean(eval_losses) # Prefix all keys with eval_ for key in list(metrics.keys()): if not key.startswith("eval_"): metrics[f"eval_{key}"] = metrics.pop(key) return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor: assert self.args.local_rank != -1 output_tensors = [ tensor.clone() for _ in range(torch.distributed.get_world_size()) ] torch.distributed.all_gather(output_tensors, tensor) concat = torch.cat(output_tensors, dim=0) # truncate the dummy elements added by SequentialDistributedSampler output = concat[:num_total_examples] return output
def main(): # load expert data print(args.data_set_path) dataset = ExpertDataSet(args.data_set_path) data_loader = data.DataLoader( dataset=dataset, batch_size=args.expert_batch_size, shuffle=True, num_workers=0 ) p_state_sizes = [args.n_state, args.n_state, args.n_state, args.n_state + args.n_action] p_action_sizes = [args.n_onehot_action, args.n_multihot_action, args.n_continuous_action, args.n_continuous_state - args.n_continuous_action] d_state_sizes = [args.n_state, args.n_state, args.n_state, args.n_state] d_action_sizes = [args.n_onehot_action, args.n_multihot_action, args.n_continuous_action, args.n_continuous_state - args.n_continuous_action] policy = MultiPolicy(p_state_sizes, p_action_sizes, onehot_action_sections, onehot_state_sections, state_0 = dataset.state) discriminator = MultiDiscriminator(d_state_sizes, d_action_sizes) discriminator_criterion = nn.BCELoss() if write_scalar: writer = SummaryWriter(log_dir='runs/' + model_name) # load net models if load_model: discriminator = torch.load('./model_pkl/multi_policy/D_' + model_name + '.pkl') policy = torch.load('./model_pkl/multi_policy/P_' + model_name + '.pkl') print('############# start training ##############') # update discriminator num = 0 for ep in tqdm(range(args.training_epochs)): # collect data from environment for ppo update policy.train() discriminator.train() start_time = time.time() memory, n_trajs = policy.collect_samples(batch_size=args.sample_batch_size) # print('sample_data_time:{}'.format(time.time()-start_time)) batch = memory.sample() gen_state = torch.cat(batch.state, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach() gen_action = torch.cat(batch.action, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach() gen_next_state = torch.cat(batch.next_state, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach() old_log_prob = torch.cat(batch.old_log_prob, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach() mask = torch.cat(batch.mask, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach() gen_d_state, gen_d_action = make_d_inputs(gen_state, gen_action, gen_next_state) if ep % 1 == 0: # if (d_slow_flag and ep % 50 == 0) or (not d_slow_flag and ep % 1 == 0): d_loss = torch.empty(0, device=device) p_loss = torch.empty(0, device=device) v_loss = torch.empty(0, device=device) gen_r = torch.empty(0, device=device) expert_r = torch.empty(0, device=device) for expert_state_batch, expert_action_batch, expert_next_state_batch in data_loader: expert_d_state, expert_d_action = make_d_inputs(expert_state_batch.to(device), expert_action_batch.to(device), expert_next_state_batch.to(device)) gen_r = discriminator(gen_d_state, gen_d_action) expert_r = discriminator(expert_d_state, expert_d_action) discriminator.optimizer.zero_grad() d_loss = discriminator_criterion(gen_r, torch.zeros(gen_r.shape, device=device)) + \ discriminator_criterion(expert_r,torch.ones(expert_r.shape, device=device)) variance = 0.5 * torch.var(gen_r.to(device)) + 0.5 * torch.var(expert_r.to(device)) total_d_loss = d_loss - 10 * variance d_loss.backward() # total_d_loss.backward() discriminator.optimizer.step() if write_scalar: writer.add_scalar('loss/d_loss', d_loss, ep) writer.add_scalar('loss/total_d_loss', total_d_loss, ep) writer.add_scalar('loss/variance', 10 * variance, ep) if ep % 1 == 0: # update PPO gen_r = discriminator(gen_d_state, gen_d_action) optimize_iter_num = int(math.ceil(gen_state.shape[0] / args.ppo_mini_batch_size)) for ppo_ep in range(args.ppo_optim_epoch): for i in 
range(optimize_iter_num): num += 1 index = slice(i * args.ppo_mini_batch_size, min((i + 1) * args.ppo_mini_batch_size, gen_state.shape[0])) gen_state_batch, gen_action_batch, gen_next_state_batch, old_log_prob_batch, mask_batch, gen_r_batch = \ gen_state[index], gen_action[index], gen_next_state[index], old_log_prob[index], mask[index], gen_r[index] v_loss, p_loss = ppo_step(policy, gen_state_batch, gen_action_batch, gen_next_state_batch, gen_r_batch, old_log_prob_batch, mask_batch, args.ppo_clip_epsilon) policy.eval() discriminator.eval() gen_d_state, gen_d_action = make_d_inputs(gen_state, gen_action, gen_next_state) expert_d_state, expert_d_action = make_d_inputs(expert_state_batch.to(device), expert_action_batch.to(device), expert_next_state_batch.to(device)) gen_r = discriminator(gen_d_state, gen_d_action) expert_r = discriminator(expert_d_state, expert_d_action) gen_r_noise = gen_r.mean(dim=0) expert_r_noise = expert_r.mean(dim=0) gen_r = discriminator(gen_d_state, gen_d_action, noise=False) expert_r = discriminator(expert_d_state, expert_d_action, noise=False) if write_scalar: writer.add_scalar('gen_r_accurate/onehot', gen_r.mean(dim=0)[0], ep) writer.add_scalar('gen_r_accurate/multihot', gen_r.mean(dim=0)[1], ep) writer.add_scalar('gen_r_accurate/continuous', gen_r.mean(dim=0)[2], ep) writer.add_scalar('gen_r_accurate/next_state', gen_r.mean(dim=0)[3], ep) writer.add_scalar('expert_r_accurate/onehot', expert_r.mean(dim=0)[0], ep) writer.add_scalar('expert_r_accurate/multihot', expert_r.mean(dim=0)[1], ep) writer.add_scalar('expert_r_accurate/continuous', expert_r.mean(dim=0)[2], ep) writer.add_scalar('expert_r_accurate/next_state', expert_r.mean(dim=0)[3], ep) writer.add_scalar('gen_r_with_noise/onehot', gen_r_noise[0], ep) writer.add_scalar('gen_r_with_noise/multihot', gen_r_noise[1], ep) writer.add_scalar('gen_r_with_noise/continuous', gen_r_noise[2], ep) writer.add_scalar('gen_r_with_noise/next_state', gen_r_noise[3], ep) writer.add_scalar('expert_r_with_noise/onehot', expert_r_noise[0], ep) writer.add_scalar('expert_r_with_noise/multihot', expert_r_noise[1], ep) writer.add_scalar('expert_r_with_noise/continuous', expert_r_noise[2], ep) writer.add_scalar('expert_r_with_noise/next_state', expert_r_noise[3], ep) writer.add_scalar('total/gen_r_accurate', gen_r.mean(), ep) writer.add_scalar('total/expert_r_accurate', expert_r.mean(), ep) writer.add_scalar('total/gen_r_with_noise', gen_r_noise.mean(), ep) writer.add_scalar('total/expert_r_with_noise', expert_r_noise.mean(), ep) print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5) print('gen_r_noise:', gen_r_noise) print('expert_r_noise:', expert_r_noise) print('gen_r:', gen_r.mean(dim=0)) print('expert_r:', expert_r.mean(dim=0)) print('d_loss', d_loss.item()) # save models if model_name is not None: torch.save(discriminator, './model_pkl/multi_policy/D_' + model_name + '.pkl') torch.save(policy, './model_pkl/multi_policy/P_' + model_name + '.pkl') memory.clear_memory()
for i, batch in enumerate(tqdm(test_loader)): inp, gt, gt_flag = process_valBatch(batch) inp = Variable(inp).float().to(device) with autocast(enabled=args.amp): out = model(inp) out = out.type(inp.dtype) for b in range(len(batch['filename'])): metrics = saver.CalcNSave(out[b,...].detach().cpu().squeeze(), inp[b,...].detach().cpu().squeeze(), gt[b,...].squeeze().float() if gt_flag[b] else None, batch['filename'][b].split(".")[0]) if metrics is not None: metrics['file'] = batch['filename'][b] test_metrics.append(metrics) ssim = round(metrics['SSIMOut'], 4) test_ssim.append(ssim) runningSSIM.append(ssim) logging.info('[%d/%d] Test SSIM: %.4f' % (i, len(test_loader), ssim)) # For tensorboard if i % args.log_freq == 0: niter = len(test_loader)+i tb_writer.add_scalar('Test/SSIM', median(runningSSIM), niter) runningSSIM = [] if len(test_metrics) > 0: print("Median SSIM: "+str(median(test_ssim))) df = pd.DataFrame.from_dict(test_metrics) df.to_csv(os.path.join(args.save_path, 'Results.csv'), index=False)
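# The evaluation loop above accumulates per-sample SSIM values and logs their median to TensorBoard
# every args.log_freq batches. A hedged, self-contained sketch of that logging pattern follows;
# skimage's structural_similarity stands in for the project-specific saver.CalcNSave, and the random
# images are placeholders.
from statistics import median
import numpy as np
from skimage.metrics import structural_similarity
from torch.utils.tensorboard import SummaryWriter

def log_ssim_stream(image_pairs, log_freq=10, logdir="runs/ssim_demo"):
    writer = SummaryWriter(logdir)
    running = []
    for i, (out_img, gt_img) in enumerate(image_pairs):
        running.append(structural_similarity(out_img, gt_img, data_range=float(gt_img.max() - gt_img.min())))
        if i % log_freq == 0 and running:
            writer.add_scalar('Test/SSIM', median(running), i)  # log the running median, then reset
            running = []
    writer.close()

pairs = [(np.random.rand(64, 64), np.random.rand(64, 64)) for _ in range(25)]
log_ssim_stream(pairs, log_freq=10)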
class Logger: """ A general-purpose logger. Makes it easy to save diagnostics, hyper-parameter configurations, the state of a training run, and the trained model. """ def __init__( self, log_dir, output_fname='progress.csv', debug: bool = False, exp_name=None, level: int = 1, # verbosity level use_tensor_board=True, verbose=True): """ Initialize a Logger. Args: log_dir (string): A directory for saving results to. If ``None``, defaults to a temp directory of the form ``/tmp/experiments/somerandomnumber``. output_fname (string): Name for the tab-separated-value file containing metrics logged throughout a training run. Defaults to ``progress.txt``. exp_name (string): Experiment name. If you run multiple training runs and give them all the same ``exp_name``, the plotter will know to group them. (Use case: if you run the same hyperparameter configuration with multiple random seeds, you should give them all the same ``exp_name``.) """ self.log_dir = log_dir self.debug = debug if proc_id() == 0 else False self.level = level # only the MPI root process is allowed to print information to console self.verbose = verbose if proc_id() == 0 else False if proc_id() == 0: os.makedirs(self.log_dir, exist_ok=True) self.output_file = open(osp.join(self.log_dir, output_fname), 'w') atexit.register(self.output_file.close) print( colorize(f"Logging data to {self.output_file.name}", 'cyan', bold=True)) else: self.output_file = None self.epoch = 0 self.first_row = True self.log_headers = [] self.log_current_row = {} self.exp_name = exp_name self.torch_saver_elements = None # Setup tensor board logging if enabled and MPI root process self.summary_writer = SummaryWriter(os.path.join(self.log_dir, 'tb')) \ if use_tensor_board and proc_id() == 0 else None def close(self): """Close opened output files immediately after training in order to avoid number of open files overflow. Avoids the following error: OSError: [Errno 24] Too many open files """ if proc_id() == 0: self.output_file.close() def debug(self, msg, color='yellow'): """Print a colorized message to stdout.""" if self.debug: print(colorize(msg, color, bold=False)) def log(self, msg, color='green'): """Print a colorized message to stdout.""" if self.verbose and self.level > 0: print(colorize(msg, color, bold=False)) def log_tabular(self, key, val): """ Log a value of some diagnostic. Call this only once for each diagnostic quantity, each iteration. After using ``log_tabular`` to store values for each diagnostic, make sure to call ``dump_tabular`` to write them out to file and stdout (otherwise they will not get saved anywhere). """ if self.first_row: self.log_headers.append(key) else: assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration" % key assert key not in self.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key self.log_current_row[key] = val def save_config(self, config): """ Log an experiment configuration. Call this once at the top of your experiment, passing in all important config vars as a dict. This will serialize the config to JSON, while handling anything which can't be serialized in a graceful way (writing as informative a string as possible). Example use: .. 
code-block:: python logger = EpochLogger(**logger_kwargs) logger.save_config(locals()) """ if proc_id() == 0: # only root process logs configurations config_json = convert_json(config) if self.exp_name is not None: config_json['exp_name'] = self.exp_name output = json.dumps(config_json, separators=(',', ':\t'), indent=4, sort_keys=True) if self.verbose and self.level > 0: print(colorize('Run with config:', color='yellow', bold=True)) print(output) with open(osp.join(self.log_dir, "config.json"), 'w') as out: out.write(output) def save_state(self, state_dict, itr=None): """ Saves the state of an experiment. To be clear: this is about saving *state*, not logging diagnostics. All diagnostic logging is separate from this function. This function will save whatever is in ``state_dict``---usually just a copy of the environment---and the most recent parameters for the model you previously set up saving for with ``setup_tf_saver``. Call with any frequency you prefer. If you only want to maintain a single state and overwrite it at each call with the most recent version, leave ``itr=None``. If you want to keep all of the states you save, provide unique (increasing) values for 'itr'. Args: state_dict (dict): Dictionary containing essential elements to describe the current state of training. itr: An int, or None. Current iteration of training. """ if proc_id() == 0: fname = 'state.pkl' if itr is None else 'state%d.pkl' % itr try: joblib.dump(state_dict, osp.join(self.log_dir, fname)) except: self.log('Warning: could not pickle state_dict.', color='red') if hasattr(self, 'torch_saver_elements'): self.torch_save(itr) def setup_torch_saver(self, what_to_save): """ Set up easy model saving for a single PyTorch model. Because PyTorch saving and loading is especially painless, this is very minimal; we just need references to whatever we would like to pickle. This is integrated into the logger because the logger knows where the user would like to save information about this training run. Args: what_to_save: Any PyTorch model or serializable object containing PyTorch models. """ self.torch_saver_elements = what_to_save def torch_save(self, itr=None): """ Saves the PyTorch model (or models). """ if proc_id() == 0: self.log('Save model to disk...') assert self.torch_saver_elements is not None,\ "First have to setup saving with self.setup_torch_saver" fpath = 'torch_save' fpath = osp.join(self.log_dir, fpath) fname = 'model' + ('%d' % itr if itr is not None else '') + '.pt' fname = osp.join(fpath, fname) os.makedirs(fpath, exist_ok=True) with warnings.catch_warnings(): warnings.simplefilter("ignore") # We are using a non-recommended way of saving PyTorch models, # by pickling whole objects (which are dependent on the exact # directory structure at the time of saving) as opposed to # just saving network weights. This works sufficiently well # for the purposes of Spinning Up, but you may want to do # something different for your personal PyTorch project. # We use a catch_warnings() context to avoid the warnings about # not being able to save the source code. torch.save(self.torch_saver_elements, fname) torch.save(self.torch_saver_elements.state_dict(), fname) self.log('Done.') def dump_tabular(self) -> None: """ Write all of the diagnostics from the current iteration. Writes both to stdout, and to the output file. 
""" if proc_id() == 0: vals = list() self.epoch += 1 # Print formatted information into console key_lens = [len(key) for key in self.log_headers] max_key_len = max(15, max(key_lens)) keystr = '%' + '%d' % max_key_len fmt = "| " + keystr + "s | %15s |" n_slashes = 22 + max_key_len print("-" * n_slashes) if self.verbose and self.level > 0 else None for key in self.log_headers: val = self.log_current_row.get(key, "") valstr = "%8.3g" % val if hasattr(val, "__float__") else val if self.verbose and self.level > 0: print(fmt % (key, valstr)) vals.append(val) if self.verbose and self.level > 0: print("-" * n_slashes, flush=True) # Write into the output file (can be any text file format, e.g. CSV) if self.output_file is not None: if self.first_row: self.output_file.write(",".join(self.log_headers) + "\n") self.output_file.write(",".join(map(str, vals)) + "\n") self.output_file.flush() if self.summary_writer is not None: [ self.summary_writer.add_scalar(k, v, global_step=self.epoch) for (k, v) in zip(self.log_headers, vals) ] # Flushes the event file to disk. Call this method to make sure # that all pending events have been written to disk. self.summary_writer.flush() # free logged information in all processes... self.log_current_row.clear() self.first_row = False
class SupervisedNetwork: def __init__(self, config): now = datetime.now() date_time = now.strftime("%m-%d-%H-%M-%S") self.tensorboard_writer = SummaryWriter(log_dir='runs/Supervised-' + date_time) self.base_path = config['base_path'] self.stamp = config['stamp'] self.meta_epochs = config['num_meta_epochs'] self.early_stopping = config['early_stopping'] self.stopping_threshold = config.get('stopping_threshold', 1e-3) if 'seq' in config['meta_model']: self.model = SeqSupervisedNetwork(config) logger.info('Supervised network instantiated') def training(self, train_dataloader, val_dataloader, tags): best_loss = float('inf') best_f1 = 0 patience = 0 model_path = os.path.join(self.base_path, 'saved_models', 'SupervisedLearner-{}.h5'.format(self.stamp)) classifier_path = os.path.join(self.base_path, 'saved_models', 'SupervisedClassifier-{}.h5'.format(self.stamp)) logger.info('Model name: SupervisedLearner-{}.h5'.format(self.stamp)) for epoch in range(self.meta_epochs): logger.info('Starting epoch {}/{}'.format(epoch + 1, self.meta_epochs)) avg_loss, avg_accuracy, avg_precision, avg_recall, avg_f1 = self.model(train_dataloader, tags=tags, writer=self.tensorboard_writer) logger.info('Train epoch {}: Avg loss = {:.5f}, avg accuracy = {:.5f}, avg precision = {:.5f}, ' 'avg recall = {:.5f}, avg F1 score = {:.5f}'.format(epoch + 1, avg_loss, avg_accuracy, avg_precision, avg_recall, avg_f1)) self.tensorboard_writer.add_scalar('Loss/train', avg_loss, global_step=epoch + 1) self.tensorboard_writer.add_scalar('F1/train', avg_f1, global_step=epoch + 1) avg_loss, avg_accuracy, avg_precision, avg_recall, avg_f1 = self.model(val_dataloader, tags=tags, testing=True) logger.info('Val epoch {}: Avg loss = {:.5f}, avg accuracy = {:.5f}, avg precision = {:.5f}, ' 'avg recall = {:.5f}, avg F1 score = {:.5f}'.format(epoch + 1, avg_loss, avg_accuracy, avg_precision, avg_recall, avg_f1)) self.tensorboard_writer.add_scalar('Loss/val', avg_loss, global_step=epoch + 1) self.tensorboard_writer.add_scalar('F1/val', avg_f1, global_step=epoch + 1) if avg_f1 > best_f1 + self.stopping_threshold: patience = 0 best_loss = avg_loss best_f1 = avg_f1 logger.info('Saving the model since the F1 improved') torch.save(self.model.learner.state_dict(), model_path) torch.save(self.model.classifier.state_dict(), classifier_path) logger.info('') else: patience += 1 logger.info('F1 did not improve') logger.info('') if patience == self.early_stopping: break # Log params and grads into tensorboard for name, param in self.model.named_parameters(): if param.requires_grad and param.grad is not None: self.tensorboard_writer.add_histogram('Params/' + name, param.data.view(-1), global_step=epoch + 1) self.tensorboard_writer.add_histogram('Grads/' + name, param.grad.data.view(-1), global_step=epoch + 1) self.model.learner.load_state_dict(torch.load(model_path)) self.model.classifier.load_state_dict(torch.load(classifier_path)) return best_f1 def testing(self, test_dataloader, tags): logger.info('---------- Supervised testing starts here ----------') _, accuracy, precision, recall, f1_score = self.model(test_dataloader, tags=tags, testing=True) logger.info('Avg meta-testing metrics: Accuracy = {:.5f}, precision = {:.5f}, recall = {:.5f}, ' 'F1 score = {:.5f}'.format(accuracy, precision, recall, f1_score)) return f1_score
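# training() above stops when the validation F1 has not improved by more than stopping_threshold for
# `early_stopping` consecutive epochs. A self-contained sketch of that rule with toy F1 values:
def early_stop_demo(f1_per_epoch, early_stopping=3, stopping_threshold=1e-3):
    best_f1, patience = 0.0, 0
    for epoch, f1 in enumerate(f1_per_epoch):
        if f1 > best_f1 + stopping_threshold:
            best_f1, patience = f1, 0      # improvement: this is where the model would be saved
        else:
            patience += 1                  # no meaningful improvement this epoch
        if patience == early_stopping:
            return epoch, best_f1          # early stop
    return len(f1_per_epoch) - 1, best_f1

print(early_stop_demo([0.40, 0.52, 0.55, 0.551, 0.552, 0.5515]))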
class TrainManager: """ Manages training loop, validations, learning rate scheduling and early stopping.""" def __init__(self, model: SignModel, config: dict) -> None: """ Creates a new TrainManager for a model, specified as in configuration. :param model: torch module defining the model :param config: dictionary containing the training configurations """ train_config = config["training"] # files for logging and storing self.model_dir = make_model_dir(train_config["model_dir"], overwrite=train_config.get( "overwrite", False)) self.logger = make_logger(model_dir=self.model_dir) self.logging_freq = train_config.get("logging_freq", 100) self.valid_report_file = "{}/validations.txt".format(self.model_dir) self.tb_writer = SummaryWriter(log_dir=self.model_dir + "/tensorboard/") # input self.feature_size = (sum(config["data"]["feature_size"]) if isinstance( config["data"]["feature_size"], list) else config["data"]["feature_size"]) self.feature_size_features = config["data"]["feature_size_cnn"] self.dataset_version = config["data"].get("version", "phoenix_2014_trans") # model self.model = model self.txt_pad_index = self.model.txt_pad_index self.txt_bos_index = self.model.txt_bos_index self._log_parameters_list() # Check if we are doing only recognition or only translation or both self.do_recognition = (config["training"].get( "recognition_loss_weight", 1.0) > 0.0) self.do_translation = (config["training"].get( "translation_loss_weight", 1.0) > 0.0) # Get Recognition and Translation specific parameters if self.do_recognition: self._get_recognition_params(train_config=train_config) if self.do_translation: self._get_translation_params(train_config=train_config) # optimization self.last_best_lr = train_config.get("learning_rate", -1) self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8) self.clip_grad_fun = build_gradient_clipper(config=train_config) self.optimizer = build_optimizer(config=train_config, parameters=model.parameters()) self.batch_multiplier = train_config.get("batch_multiplier", 1) # validation & early stopping self.validation_freq = train_config.get("validation_freq", 100) self.num_valid_log = train_config.get("num_valid_log", 5) self.ckpt_queue = queue.Queue( maxsize=train_config.get("keep_last_ckpts", 5)) self.eval_metric = train_config.get("eval_metric", "bleu") if self.eval_metric not in ["bleu", "chrf", "wer", "rouge"]: raise ValueError("Invalid setting for 'eval_metric': {}".format( self.eval_metric)) self.early_stopping_metric = train_config.get("early_stopping_metric", "eval_metric") # if we schedule after BLEU/chrf, we want to maximize it, else minimize # early_stopping_metric decides on how to find the early stopping point: # ckpts are written when there's a new high/low score for this metric if self.early_stopping_metric in [ "ppl", "translation_loss", "recognition_loss", ]: self.minimize_metric = True elif self.early_stopping_metric == "eval_metric": if self.eval_metric in ["bleu", "chrf", "rouge"]: assert self.do_translation self.minimize_metric = False else: # eval metric that has to get minimized (not yet implemented) self.minimize_metric = True else: raise ValueError( "Invalid setting for 'early_stopping_metric': {}".format( self.early_stopping_metric)) # data_augmentation parameters self.frame_subsampling_ratio = config["data"].get( "frame_subsampling_ratio", None) self.random_frame_subsampling = config["data"].get( "random_frame_subsampling", None) self.random_frame_masking_ratio = config["data"].get( "random_frame_masking_ratio", None) # learning 
rate scheduling self.scheduler, self.scheduler_step_at = build_scheduler( config=train_config, scheduler_mode="min" if self.minimize_metric else "max", optimizer=self.optimizer, hidden_size=config["model"]["encoder"]["hidden_size"], ) # data & batch handling self.level = config["data"]["level"] if self.level not in ["word", "bpe", "char"]: raise ValueError("Invalid segmentation level': {}".format( self.level)) self.shuffle = train_config.get("shuffle", True) self.epochs = train_config["epochs"] self.batch_size = train_config["batch_size"] self.batch_type = train_config.get("batch_type", "sentence") self.eval_batch_size = train_config.get("eval_batch_size", self.batch_size) self.eval_batch_type = train_config.get("eval_batch_type", self.batch_type) self.use_cuda = train_config["use_cuda"] if self.use_cuda: self.model.cuda() if self.do_translation: self.translation_loss_function.cuda() if self.do_recognition: self.recognition_loss_function.cuda() # initialize training statistics self.steps = 0 # stop training if this flag is True by reaching learning rate minimum self.stop = False self.total_txt_tokens = 0 self.total_gls_tokens = 0 self.best_ckpt_iteration = 0 # initial values for best scores self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf self.best_all_ckpt_scores = {} # comparison function for scores self.is_best = ( lambda score: score < self.best_ckpt_score if self.minimize_metric else score > self.best_ckpt_score) # model parameters if "load_model" in train_config.keys(): model_load_path = train_config["load_model"] self.logger.info("Loading model from %s", model_load_path) reset_best_ckpt = train_config.get("reset_best_ckpt", False) reset_scheduler = train_config.get("reset_scheduler", False) reset_optimizer = train_config.get("reset_optimizer", False) self.init_from_checkpoint( model_load_path, reset_best_ckpt=reset_best_ckpt, reset_scheduler=reset_scheduler, reset_optimizer=reset_optimizer, ) def _get_recognition_params(self, train_config) -> None: # NOTE (Cihan): The blank label is the silence index in the gloss vocabulary. # There is an assertion in the GlossVocabulary class's __init__. # This is necessary to do TensorFlow decoding, as it is hardcoded # Currently it is hardcoded as 0. 
self.gls_silence_token = self.model.gls_vocab.stoi[SIL_TOKEN] assert self.gls_silence_token == 0 self.recognition_loss_function = torch.nn.CTCLoss( blank=self.gls_silence_token, zero_infinity=True) self.recognition_loss_weight = train_config.get( "recognition_loss_weight", 1.0) self.eval_recognition_beam_size = train_config.get( "eval_recognition_beam_size", 1) def _get_translation_params(self, train_config) -> None: self.label_smoothing = train_config.get("label_smoothing", 0.0) self.translation_loss_function = XentLoss( pad_index=self.txt_pad_index, smoothing=self.label_smoothing) self.translation_normalization_mode = train_config.get( "translation_normalization", "batch") if self.translation_normalization_mode not in ["batch", "tokens"]: raise ValueError("Invalid normalization {}.".format( self.translation_normalization_mode)) self.translation_loss_weight = train_config.get( "translation_loss_weight", 1.0) self.eval_translation_beam_size = train_config.get( "eval_translation_beam_size", 1) self.eval_translation_beam_alpha = train_config.get( "eval_translation_beam_alpha", -1) self.translation_max_output_length = train_config.get( "translation_max_output_length", None) def _save_checkpoint(self) -> None: """ Save the model's current parameters and the training state to a checkpoint. The training state contains the total number of training steps, the total number of training tokens, the best checkpoint score and iteration so far, and optimizer and scheduler states. """ model_path = "{}/{}.ckpt".format(self.model_dir, self.steps) state = { "steps": self.steps, "total_txt_tokens": self.total_txt_tokens if self.do_translation else 0, "total_gls_tokens": self.total_gls_tokens if self.do_recognition else 0, "best_ckpt_score": self.best_ckpt_score, "best_all_ckpt_scores": self.best_all_ckpt_scores, "best_ckpt_iteration": self.best_ckpt_iteration, "model_state": self.model.state_dict(), "optimizer_state": self.optimizer.state_dict(), "scheduler_state": self.scheduler.state_dict() if self.scheduler is not None else None, } torch.save(state, model_path) if self.ckpt_queue.full(): to_delete = self.ckpt_queue.get() # delete oldest ckpt try: os.remove(to_delete) except FileNotFoundError: self.logger.warning( "Wanted to delete old checkpoint %s but " "file does not exist.", to_delete, ) self.ckpt_queue.put(model_path) # create/modify symbolic link for best checkpoint symlink_update("{}.ckpt".format(self.steps), "{}/best.ckpt".format(self.model_dir)) def init_from_checkpoint( self, path: str, reset_best_ckpt: bool = False, reset_scheduler: bool = False, reset_optimizer: bool = False, ) -> None: """ Initialize the trainer from a given checkpoint file. This checkpoint file contains not only model parameters, but also scheduler and optimizer states, see `self._save_checkpoint`. :param path: path to checkpoint :param reset_best_ckpt: reset tracking of the best checkpoint, use for domain adaptation with a new dev set or when using a new metric for fine-tuning. :param reset_scheduler: reset the learning rate scheduler, and do not use the one stored in the checkpoint. :param reset_optimizer: reset the optimizer, and do not use the one stored in the checkpoint. 
""" model_checkpoint = load_checkpoint(path=path, use_cuda=self.use_cuda) # restore model and optimizer parameters self.model.load_state_dict(model_checkpoint["model_state"]) if not reset_optimizer: self.optimizer.load_state_dict(model_checkpoint["optimizer_state"]) else: self.logger.info("Reset optimizer.") if not reset_scheduler: if (model_checkpoint["scheduler_state"] is not None and self.scheduler is not None): self.scheduler.load_state_dict( model_checkpoint["scheduler_state"]) else: self.logger.info("Reset scheduler.") # restore counts self.steps = model_checkpoint["steps"] self.total_txt_tokens = model_checkpoint["total_txt_tokens"] self.total_gls_tokens = model_checkpoint["total_gls_tokens"] if not reset_best_ckpt: self.best_ckpt_score = model_checkpoint["best_ckpt_score"] self.best_all_ckpt_scores = model_checkpoint[ "best_all_ckpt_scores"] self.best_ckpt_iteration = model_checkpoint["best_ckpt_iteration"] else: self.logger.info("Reset tracking of the best checkpoint.") # move parameters to cuda if self.use_cuda: self.model.cuda() def train_and_validate(self, train_data: Dataset, valid_data: Dataset) -> None: """ Train the model and validate it from time to time on the validation set. :param train_data: training data :param valid_data: validation data """ train_iter = make_data_iter( train_data, batch_size=self.batch_size, batch_type=self.batch_type, train=True, shuffle=self.shuffle, ) epoch_no = None for epoch_no in range(self.epochs): self.logger.info("EPOCH %d", epoch_no + 1) if self.scheduler is not None and self.scheduler_step_at == "epoch": self.scheduler.step(epoch=epoch_no) self.model.train() start = time.time() total_valid_duration = 0 count = self.batch_multiplier - 1 if self.do_recognition: processed_gls_tokens = self.total_gls_tokens epoch_recognition_loss = 0 if self.do_translation: processed_txt_tokens = self.total_txt_tokens epoch_translation_loss = 0 for batch in iter(train_iter): # reactivate training # create a Batch object from torchtext batch batch = Batch( is_train=True, torch_batch=batch, txt_pad_index=self.txt_pad_index, sgn_dim=self.feature_size, features_dim=self.feature_size_features, use_cuda=self.use_cuda, frame_subsampling_ratio=self.frame_subsampling_ratio, random_frame_subsampling=self.random_frame_subsampling, random_frame_masking_ratio=self.random_frame_masking_ratio, ) # only update every batch_multiplier batches # see https://medium.com/@davidlmorton/ # increasing-mini-batch-size-without-increasing- # memory-6794e10db672 update = count == 0 recognition_loss, translation_loss = self._train_batch( batch, update=update) if self.do_recognition: self.tb_writer.add_scalar("train/train_recognition_loss", recognition_loss, self.steps) epoch_recognition_loss += recognition_loss.detach().cpu( ).numpy() if self.do_translation: self.tb_writer.add_scalar("train/train_translation_loss", translation_loss, self.steps) epoch_translation_loss += translation_loss.detach().cpu( ).numpy() count = self.batch_multiplier if update else count count -= 1 if (self.scheduler is not None and self.scheduler_step_at == "step" and update): self.scheduler.step() # log learning progress if self.steps % self.logging_freq == 0 and update: elapsed = time.time() - start - total_valid_duration log_out = "[Epoch: {:03d} Step: {:08d}] ".format( epoch_no + 1, self.steps, ) if self.do_recognition: elapsed_gls_tokens = (self.total_gls_tokens - processed_gls_tokens) processed_gls_tokens = self.total_gls_tokens log_out += "Batch Recognition Loss: {:10.6f} => ".format( recognition_loss) 
log_out += "Gls Tokens per Sec: {:8.0f} || ".format( elapsed_gls_tokens / elapsed) if self.do_translation: elapsed_txt_tokens = (self.total_txt_tokens - processed_txt_tokens) processed_txt_tokens = self.total_txt_tokens log_out += "Batch Translation Loss: {:10.6f} => ".format( translation_loss) log_out += "Txt Tokens per Sec: {:8.0f} || ".format( elapsed_txt_tokens / elapsed) log_out += "Lr: {:.6f}".format( self.optimizer.param_groups[0]["lr"]) self.logger.info(log_out) start = time.time() total_valid_duration = 0 # validate on the entire dev set if self.steps % self.validation_freq == 0 and update: valid_start_time = time.time() # TODO (Cihan): There must be a better way of passing # these recognition only and translation only parameters! # Maybe have a NamedTuple with optional fields? # Hmm... Future Cihan's problem. val_res = validate_on_data( model=self.model, data=valid_data, batch_size=self.eval_batch_size, use_cuda=self.use_cuda, batch_type=self.eval_batch_type, dataset_version=self.dataset_version, sgn_dim=self.feature_size, features_dim=self.feature_size_features, txt_pad_index=self.txt_pad_index, # Recognition Parameters do_recognition=self.do_recognition, recognition_loss_function=self. recognition_loss_function if self.do_recognition else None, recognition_loss_weight=self.recognition_loss_weight if self.do_recognition else None, recognition_beam_size=self.eval_recognition_beam_size if self.do_recognition else None, # Translation Parameters do_translation=self.do_translation, translation_loss_function=self. translation_loss_function if self.do_translation else None, translation_max_output_length=self. translation_max_output_length if self.do_translation else None, level=self.level if self.do_translation else None, translation_loss_weight=self.translation_loss_weight if self.do_translation else None, translation_beam_size=self.eval_translation_beam_size if self.do_translation else None, translation_beam_alpha=self.eval_translation_beam_alpha if self.do_translation else None, frame_subsampling_ratio=self.frame_subsampling_ratio, ) self.model.train() if self.do_recognition: # Log Losses and ppl self.tb_writer.add_scalar( "valid/valid_recognition_loss", val_res["valid_recognition_loss"], self.steps, ) self.tb_writer.add_scalar( "valid/wer", val_res["valid_scores"]["wer"], self.steps) self.tb_writer.add_scalars( "valid/wer_scores", val_res["valid_scores"]["wer_scores"], self.steps, ) if self.do_translation: self.tb_writer.add_scalar( "valid/valid_translation_loss", val_res["valid_translation_loss"], self.steps, ) self.tb_writer.add_scalar("valid/valid_ppl", val_res["valid_ppl"], self.steps) # Log Scores self.tb_writer.add_scalar( "valid/chrf", val_res["valid_scores"]["chrf"], self.steps) self.tb_writer.add_scalar( "valid/rouge", val_res["valid_scores"]["rouge"], self.steps) self.tb_writer.add_scalar( "valid/bleu", val_res["valid_scores"]["bleu"], self.steps) self.tb_writer.add_scalars( "valid/bleu_scores", val_res["valid_scores"]["bleu_scores"], self.steps, ) if self.early_stopping_metric == "recognition_loss": assert self.do_recognition ckpt_score = val_res["valid_recognition_loss"] elif self.early_stopping_metric == "translation_loss": assert self.do_translation ckpt_score = val_res["valid_translation_loss"] elif self.early_stopping_metric in ["ppl", "perplexity"]: assert self.do_translation ckpt_score = val_res["valid_ppl"] else: ckpt_score = val_res["valid_scores"][self.eval_metric] new_best = False if self.is_best(ckpt_score): self.best_ckpt_score = ckpt_score 
self.best_all_ckpt_scores = val_res["valid_scores"] self.best_ckpt_iteration = self.steps self.logger.info( "Hooray! New best validation result [%s]!", self.early_stopping_metric, ) if self.ckpt_queue.maxsize > 0: self.logger.info("Saving new checkpoint.") new_best = True self._save_checkpoint() if (self.scheduler is not None and self.scheduler_step_at == "validation"): prev_lr = self.scheduler.optimizer.param_groups[0][ "lr"] self.scheduler.step(ckpt_score) now_lr = self.scheduler.optimizer.param_groups[0]["lr"] if prev_lr != now_lr: if self.last_best_lr != prev_lr: self.stop = True # append to validation report self._add_report( valid_scores=val_res["valid_scores"], valid_recognition_loss=val_res["valid_recognition_loss"] if self.do_recognition else None, valid_translation_loss=val_res["valid_translation_loss"] if self.do_translation else None, valid_ppl=val_res["valid_ppl"] if self.do_translation else None, eval_metric=self.eval_metric, new_best=new_best, ) valid_duration = time.time() - valid_start_time total_valid_duration += valid_duration self.logger.info( "Validation result at epoch %3d, step %8d: duration: %.4fs\n\t" "Recognition Beam Size: %d\t" "Translation Beam Size: %d\t" "Translation Beam Alpha: %d\n\t" "Recognition Loss: %4.5f\t" "Translation Loss: %4.5f\t" "PPL: %4.5f\n\t" "Eval Metric: %s\n\t" "WER %3.2f\t(DEL: %3.2f,\tINS: %3.2f,\tSUB: %3.2f)\n\t" "BLEU-4 %.2f\t(BLEU-1: %.2f,\tBLEU-2: %.2f,\tBLEU-3: %.2f,\tBLEU-4: %.2f)\n\t" "CHRF %.2f\t" "ROUGE %.2f", epoch_no + 1, self.steps, valid_duration, self.eval_recognition_beam_size if self.do_recognition else -1, self.eval_translation_beam_size if self.do_translation else -1, self.eval_translation_beam_alpha if self.do_translation else -1, val_res["valid_recognition_loss"] if self.do_recognition else -1, val_res["valid_translation_loss"] if self.do_translation else -1, val_res["valid_ppl"] if self.do_translation else -1, self.eval_metric.upper(), # WER val_res["valid_scores"]["wer"] if self.do_recognition else -1, val_res["valid_scores"]["wer_scores"]["del_rate"] if self.do_recognition else -1, val_res["valid_scores"]["wer_scores"]["ins_rate"] if self.do_recognition else -1, val_res["valid_scores"]["wer_scores"]["sub_rate"] if self.do_recognition else -1, # BLEU val_res["valid_scores"]["bleu"] if self.do_translation else -1, val_res["valid_scores"]["bleu_scores"]["bleu1"] if self.do_translation else -1, val_res["valid_scores"]["bleu_scores"]["bleu2"] if self.do_translation else -1, val_res["valid_scores"]["bleu_scores"]["bleu3"] if self.do_translation else -1, val_res["valid_scores"]["bleu_scores"]["bleu4"] if self.do_translation else -1, # Other val_res["valid_scores"]["chrf"] if self.do_translation else -1, val_res["valid_scores"]["rouge"] if self.do_translation else -1, ) self._log_examples( sequences=[s for s in valid_data.sequence], gls_references=val_res["gls_ref"] if self.do_recognition else None, gls_hypotheses=val_res["gls_hyp"] if self.do_recognition else None, txt_references=val_res["txt_ref"] if self.do_translation else None, txt_hypotheses=val_res["txt_hyp"] if self.do_translation else None, ) valid_seq = [s for s in valid_data.sequence] # store validation set outputs and references if self.do_recognition: self._store_outputs("dev.hyp.gls", valid_seq, val_res["gls_hyp"], "gls") self._store_outputs("references.dev.gls", valid_seq, val_res["gls_ref"]) if self.do_translation: self._store_outputs("dev.hyp.txt", valid_seq, val_res["txt_hyp"], "txt") self._store_outputs("references.dev.txt", valid_seq, val_res["txt_ref"]) 
if self.stop: break if self.stop: if (self.scheduler is not None and self.scheduler_step_at == "validation" and self.last_best_lr != prev_lr): self.logger.info( "Training ended since there were no improvements in" "the last learning rate step: %f", prev_lr, ) else: self.logger.info( "Training ended since minimum lr %f was reached.", self.learning_rate_min, ) break self.logger.info( "Epoch %3d: Total Training Recognition Loss %.2f " " Total Training Translation Loss %.2f ", epoch_no + 1, epoch_recognition_loss if self.do_recognition else -1, epoch_translation_loss if self.do_translation else -1, ) else: self.logger.info("Training ended after %3d epochs.", epoch_no + 1) self.logger.info( "Best validation result at step %8d: %6.2f %s.", self.best_ckpt_iteration, self.best_ckpt_score, self.early_stopping_metric, ) self.tb_writer.close() # close Tensorboard writer def _train_batch(self, batch: Batch, update: bool = True) -> (Tensor, Tensor): """ Train the model on one batch: Compute the loss, make a gradient step. :param batch: training batch :param update: if False, only store gradient. if True also make update :return normalized_recognition_loss: Normalized recognition loss :return normalized_translation_loss: Normalized translation loss """ recognition_loss, translation_loss = self.model.get_loss_for_batch( batch=batch, recognition_loss_function=self.recognition_loss_function if self.do_recognition else None, translation_loss_function=self.translation_loss_function if self.do_translation else None, recognition_loss_weight=self.recognition_loss_weight if self.do_recognition else None, translation_loss_weight=self.translation_loss_weight if self.do_translation else None, ) # normalize translation loss if self.do_translation: if self.translation_normalization_mode == "batch": txt_normalization_factor = batch.num_seqs elif self.translation_normalization_mode == "tokens": txt_normalization_factor = batch.num_txt_tokens else: raise NotImplementedError( "Only normalize by 'batch' or 'tokens'") # division needed since loss.backward sums the gradients until updated normalized_translation_loss = translation_loss / ( txt_normalization_factor * self.batch_multiplier) else: normalized_translation_loss = 0 # TODO (Cihan): Add Gloss Token normalization (?) # I think they are already being normalized by batch # I need to think about if I want to normalize them by token. if self.do_recognition: normalized_recognition_loss = recognition_loss / self.batch_multiplier else: normalized_recognition_loss = 0 total_loss = normalized_recognition_loss + normalized_translation_loss # compute gradients total_loss.backward() if self.clip_grad_fun is not None: # clip gradients (in-place) self.clip_grad_fun(params=self.model.parameters()) if update: # make gradient step self.optimizer.step() self.optimizer.zero_grad() # increment step counter self.steps += 1 # increment token counter if self.do_recognition: self.total_gls_tokens += batch.num_gls_tokens if self.do_translation: self.total_txt_tokens += batch.num_txt_tokens return normalized_recognition_loss, normalized_translation_loss def _add_report( self, valid_scores: Dict, valid_recognition_loss: float, valid_translation_loss: float, valid_ppl: float, eval_metric: str, new_best: bool = False, ) -> None: """ Append a one-line report to validation logging file. 
:param valid_scores: Dictionary of validation scores :param valid_recognition_loss: validation loss (sum over whole validation set) :param valid_translation_loss: validation loss (sum over whole validation set) :param valid_ppl: validation perplexity :param eval_metric: evaluation metric, e.g. "bleu" :param new_best: whether this is a new best model """ current_lr = -1 # ignores other param groups for now for param_group in self.optimizer.param_groups: current_lr = param_group["lr"] if new_best: self.last_best_lr = current_lr if current_lr < self.learning_rate_min: self.stop = True with open(self.valid_report_file, "a", encoding="utf-8") as opened_file: opened_file.write( "Steps: {}\t" "Recognition Loss: {:.5f}\t" "Translation Loss: {:.5f}\t" "PPL: {:.5f}\t" "Eval Metric: {}\t" "WER {:.2f}\t(DEL: {:.2f},\tINS: {:.2f},\tSUB: {:.2f})\t" "BLEU-4 {:.2f}\t(BLEU-1: {:.2f},\tBLEU-2: {:.2f},\tBLEU-3: {:.2f},\tBLEU-4: {:.2f})\t" "CHRF {:.2f}\t" "ROUGE {:.2f}\t" "LR: {:.8f}\t{}\n".format( self.steps, valid_recognition_loss if self.do_recognition else -1, valid_translation_loss if self.do_translation else -1, valid_ppl if self.do_translation else -1, eval_metric, # WER valid_scores["wer"] if self.do_recognition else -1, valid_scores["wer_scores"]["del_rate"] if self.do_recognition else -1, valid_scores["wer_scores"]["ins_rate"] if self.do_recognition else -1, valid_scores["wer_scores"]["sub_rate"] if self.do_recognition else -1, # BLEU valid_scores["bleu"] if self.do_translation else -1, valid_scores["bleu_scores"]["bleu1"] if self.do_translation else -1, valid_scores["bleu_scores"]["bleu2"] if self.do_translation else -1, valid_scores["bleu_scores"]["bleu3"] if self.do_translation else -1, valid_scores["bleu_scores"]["bleu4"] if self.do_translation else -1, # Other valid_scores["chrf"] if self.do_translation else -1, valid_scores["rouge"] if self.do_translation else -1, current_lr, "*" if new_best else "", )) def _log_parameters_list(self) -> None: """ Write all model parameters (name, shape) to the log. """ model_parameters = filter(lambda p: p.requires_grad, self.model.parameters()) n_params = sum([np.prod(p.size()) for p in model_parameters]) self.logger.info("Total params: %d", n_params) trainable_params = [ n for (n, p) in self.model.named_parameters() if p.requires_grad ] self.logger.info("Trainable parameters: %s", sorted(trainable_params)) assert trainable_params def _log_examples( self, sequences: List[str], gls_references: List[str], gls_hypotheses: List[str], txt_references: List[str], txt_hypotheses: List[str], ) -> None: """ Log `self.num_valid_log` number of samples from valid. 
:param sequences: sign video sequence names (list of strings) :param txt_hypotheses: decoded txt hypotheses (list of strings) :param txt_references: decoded txt references (list of strings) :param gls_hypotheses: decoded gls hypotheses (list of strings) :param gls_references: decoded gls references (list of strings) """ if self.do_recognition: assert len(gls_references) == len(gls_hypotheses) num_sequences = len(gls_hypotheses) if self.do_translation: assert len(txt_references) == len(txt_hypotheses) num_sequences = len(txt_hypotheses) rand_idx = np.sort( np.random.permutation(num_sequences)[:self.num_valid_log]) self.logger.info("Logging Recognition and Translation Outputs") self.logger.info("=" * 120) for ri in rand_idx: self.logger.info("Logging Sequence: %s", sequences[ri]) if self.do_recognition: gls_res = wer_single(r=gls_references[ri], h=gls_hypotheses[ri]) self.logger.info("\tGloss Reference :\t%s", gls_res["alignment_out"]["align_ref"]) self.logger.info("\tGloss Hypothesis:\t%s", gls_res["alignment_out"]["align_hyp"]) self.logger.info("\tGloss Alignment :\t%s", gls_res["alignment_out"]["alignment"]) if self.do_recognition and self.do_translation: self.logger.info("\t" + "-" * 116) if self.do_translation: txt_res = wer_single(r=txt_references[ri], h=txt_hypotheses[ri]) self.logger.info("\tText Reference :\t%s", txt_res["alignment_out"]["align_ref"]) self.logger.info("\tText Hypothesis :\t%s", txt_res["alignment_out"]["align_hyp"]) self.logger.info("\tText Alignment :\t%s", txt_res["alignment_out"]["alignment"]) self.logger.info("=" * 120) def _store_outputs(self, tag: str, sequence_ids: List[str], hypotheses: List[str], sub_folder=None) -> None: """ Write current validation outputs to file in `self.model_dir.` :param hypotheses: list of strings """ if sub_folder: out_folder = os.path.join(self.model_dir, sub_folder) if not os.path.exists(out_folder): os.makedirs(out_folder) current_valid_output_file = "{}/{}.{}".format( out_folder, self.steps, tag) else: out_folder = self.model_dir current_valid_output_file = "{}/{}".format(out_folder, tag) with open(current_valid_output_file, "w", encoding="utf-8") as opened_file: for seq, hyp in zip(sequence_ids, hypotheses): opened_file.write("{}|{}\n".format(seq, hyp))
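# _train_batch() above divides each loss by batch_multiplier and train_and_validate() only calls
# optimizer.step() on every batch_multiplier-th batch, so the accumulated gradients approximate one
# larger batch without extra memory. A self-contained sketch of that scheme with a toy model:
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
batch_multiplier = 4

optimizer.zero_grad()
for step in range(12):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = criterion(model(x), y) / batch_multiplier   # normalize, as in _train_batch
    loss.backward()                                    # gradients accumulate across mini-batches
    if (step + 1) % batch_multiplier == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # optional, mirrors clip_grad_fun
        optimizer.step()
        optimizer.zero_grad()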
# TRY NOT TO MODIFY: execute the game and log data. next_obs, reward, done, _ = env.step(action) episode_reward += reward # ALGO LOGIC: training. rb.put((obs, action, reward, next_obs, done)) if global_step > args.learning_starts and global_step % args.train_frequency == 0: s_obs, s_actions, s_rewards, s_next_obses, s_dones = rb.sample(args.batch_size) with torch.no_grad(): target_max = torch.max(target_network.forward(s_next_obses, device), dim=1)[0] td_target = torch.Tensor(s_rewards).to(device) + args.gamma * target_max * (1 - torch.Tensor(s_dones).to(device)) old_val = q_network.forward(s_obs, device).gather(1, torch.LongTensor(s_actions).view(-1,1).to(device)).squeeze() loss = loss_fn(td_target, old_val) if global_step % 100 == 0: writer.add_scalar("losses/td_loss", loss, global_step) # optimize the model optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(list(q_network.parameters()), args.max_grad_norm) optimizer.step() # update the target network if global_step % args.target_network_frequency == 0: target_network.load_state_dict(q_network.state_dict()) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs if done:
def train_melgan(args): root = Path(args.save_path) load_root = Path(args.load_path) if args.load_path else None root.mkdir(parents=True, exist_ok=True) metadata_dir = root.joinpath('metadata') metadata_dir.mkdir(exist_ok=True) #################################### # Dump arguments and create logger # #################################### with open(metadata_dir / "args.yml", "w") as f: yaml.dump(args.__dict__, f) with open(metadata_dir / "args.json", "w", encoding="utf8") as f: json.dump(args.__dict__, f, indent=4, ensure_ascii=False) eventdir = root / "events" eventdir.mkdir(exist_ok=True) writer = SummaryWriter(str(eventdir)) ####################### # Load PyTorch Models # ####################### ratios = [int(w) for w in args.ratios.split()] netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers, ratios=ratios).to(_device) netD = Discriminator(args.num_D, args.ndf, args.n_layers_D, args.downsamp_factor).to(_device) # fft = Audio2Mel(n_mel_channels=args.n_mel_channels).to(_device) if args.mode == 'default': fft = audio2mel elif args.mode == 'synthesizer': fft = audio2mel_synthesizer elif args.mode == 'mellotron': fft = audio2mel_mellotron else: raise KeyError # print(netG) # print(netD) ##################### # Create optimizers # ##################### optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9)) optD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9)) if load_root and load_root.exists(): netG.load_state_dict(torch.load(load_root)) # optG.load_state_dict(torch.load(load_root / "optG.pt")) # netD.load_state_dict(torch.load(load_root / "netD.pt")) # optD.load_state_dict(torch.load(load_root / "optD.pt")) ####################### # Create data loaders # ####################### train_set = AudioDataset(Path(args.data_path), args.seq_len, sampling_rate=args.sample_rate) test_set = AudioDataset( Path(args.data_path), # test file args.sample_rate * 4, sampling_rate=args.sample_rate, augment=False, ) # Save the list of training audio files with open(metadata_dir.joinpath('train.yml'), 'wt', encoding='utf8') as fout: yaml.dump([str(w.absolute()) for w in train_set.audio_files], fout, default_flow_style=False, encoding='utf-8', allow_unicode=True) train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=args.dataloader_num_workers, shuffle=True) test_loader = DataLoader(test_set, batch_size=1, shuffle=True) ########################## # Dumping original audio # ########################## test_voc = [] test_audio = [] for i, x_t in enumerate(test_loader): x_t = x_t.to(_device) s_t = fft(x_t).detach() test_voc.append(s_t.to(_device)) test_audio.append(x_t) audio = x_t.squeeze().cpu() oridir = root / "original" oridir.mkdir(exist_ok=True) save_sample(oridir / ("original_{}_{}.wav".format("test", i)), args.sample_rate, audio) writer.add_audio("original/{}/sample_{}.wav".format("test", i), audio, 0, sample_rate=args.sample_rate) mel_outputs = fft(x_t) writer.add_image("original/{}/sample_{}.npy".format("test", i), plot_spectrogram_to_numpy( mel_outputs[0].data.cpu().numpy()), 0, dataformats='HWC') if i == args.n_test_samples - 1: break costs = [] start = time.time() # enable cudnn autotuner to speed up training torch.backends.cudnn.benchmark = True best_mel_reconst = 1000000 step_begin = args.start_step look_steps = { step_begin + 10, step_begin + 100, step_begin + 1000, step_begin + 10000 } steps = step_begin for epoch in range(1, args.epochs + 1): print("\nEpoch {} beginning. Current step: {}".format(epoch, steps)) for iterno, x_t in enumerate( tqdm(train_loader, desc=f"Epoch-{epoch}", ncols=100)): # torch.Size([4, 1, 8192]) torch.Size([4, 80, 32]) # 8192 = 32 x 256 x_t = x_t.to(_device) s_t = fft(x_t).detach() x_pred_t = netG(s_t.to(_device)) with torch.no_grad(): s_pred_t = fft(x_pred_t.detach()) s_error = F.l1_loss(s_t, s_pred_t).item() ####################### # Train Discriminator # ####################### D_fake_det = netD(x_pred_t.to(_device).detach()) D_real = netD(x_t.to(_device)) loss_D = 0 for scale in D_fake_det: loss_D += F.relu(1 + scale[-1]).mean() for scale in D_real: loss_D += F.relu(1 - scale[-1]).mean() netD.zero_grad() loss_D.backward() optD.step() ################### # Train Generator # ################### D_fake = netD(x_pred_t.to(_device)) loss_G = 0 for scale in D_fake: loss_G += -scale[-1].mean() loss_feat = 0 feat_weights = 4.0 / (args.n_layers_D + 1) D_weights = 1.0 / args.num_D wt = D_weights * feat_weights for i in range(args.num_D): for j in range(len(D_fake[i]) - 1): loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach()) netG.zero_grad() (loss_G + args.lambda_feat * loss_feat).backward() optG.step() ###################### # Update tensorboard # ###################### costs.append( [loss_D.item(), loss_G.item(), loss_feat.item(), s_error]) steps += 1 writer.add_scalar("loss/discriminator", costs[-1][0], steps) writer.add_scalar("loss/generator", costs[-1][1], steps) writer.add_scalar("loss/feature_matching", costs[-1][2], steps) writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps) if steps % args.save_interval == 0 or steps in look_steps: st = time.time() with torch.no_grad(): for i, (voc, true_audio) in enumerate(zip(test_voc, test_audio)): pred_audio_ = netG(voc) pred_audio = pred_audio_.squeeze().cpu() gendir = root / "generated" gendir.mkdir(exist_ok=True) save_sample( gendir / ("generated_step{}_{}.wav".format(steps, i)), args.sample_rate, pred_audio) writer.add_audio( "generated/step{}/sample_{}.wav".format(steps, i), pred_audio, epoch, sample_rate=args.sample_rate, ) # Inspect the spectrogram to get a direct sense of the generated audio writer.add_image( "generated/step{}/sample_{}.npy".format(steps, i), plot_spectrogram_to_numpy( fft(pred_audio_.detach())[0].data.cpu().numpy()), epoch, dataformats='HWC') ptdir = root / "models" ptdir.mkdir(exist_ok=True) torch.save(netG.state_dict(), ptdir / "step{}_netG.pt".format(steps)) torch.save(optG.state_dict(), ptdir / "step{}_optG.pt".format(steps)) torch.save(netD.state_dict(), ptdir / "step{}_netD.pt".format(steps)) torch.save(optD.state_dict(), ptdir / "step{}_optD.pt".format(steps)) if (np.asarray(costs).mean(0)[-1] < best_mel_reconst) or ( steps % (args.save_interval * 10) == 0): best_mel_reconst = np.asarray(costs).mean(0)[-1] torch.save(netD, ptdir / "best_step{}_netD.pt".format(steps)) torch.save(netG, ptdir / "best_step{}_netG.pt".format(steps)) # print("\nTook %5.4fs to generate samples" % (time.time() - st)) # print("-" * 100) if steps % args.log_interval == 0 or steps in look_steps: print( "\nEpoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}". format( epoch, iterno, len(train_loader), 1000 * (time.time() - start) / args.log_interval, np.asarray(costs).mean(0), )) costs = [] start = time.time()
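# The MelGAN loop above uses hinge losses: the discriminator pushes real scores above +1 and
# generated scores below -1 via F.relu(1 - real) and F.relu(1 + fake), while the generator simply
# maximizes the discriminator's score on generated audio. A minimal sketch with scalar scores
# standing in for netD's multi-scale outputs:
import torch
import torch.nn.functional as F

def hinge_d_loss(real_score, fake_score):
    # discriminator objective: separate real and generated scores by a margin of 1
    return F.relu(1 - real_score).mean() + F.relu(1 + fake_score).mean()

def hinge_g_loss(fake_score):
    # generator objective: raise the discriminator's score on generated samples
    return -fake_score.mean()

real, fake = torch.randn(4, 1), torch.randn(4, 1)
print(hinge_d_loss(real, fake).item(), hinge_g_loss(fake).item())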
def train_and_eval_model(args, model, fv, ks, train_dataloader, valid_dataloader, test_dataloader, device): ## Define loss criteria, optimizer and adaptive learning scheduler criterion = nn.MSELoss(reduction='mean') optimizer = optim.RMSprop(model.parameters(), lr=args.learning_rate, alpha=0.9) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1, threshold=0.01, verbose=True) writer = SummaryWriter( "runs/" + f"BS={args.batchsize}_maxEp={args.maxepoch}_LR={args.learning_rate}_ks={ks}_model={fv}" ) for epoch in range(1, args.maxepoch + 1): model.train() running_loss = 0.0 for step, batch in enumerate(train_dataloader): x_batch = batch[0].to(device) y_batch = batch[1].to(device) optimizer.zero_grad() # zero the parameter gradients loss = criterion(model(x_batch), y_batch) loss.backward() # backpropagate loss optimizer.step() # update parameters running_loss += loss.item() / args.nbands print("Epoch = " + str(epoch) + " :Total loss = " + str(running_loss)) scheduler.step(running_loss) #this adjusts the adaptive LR scheduler writer.add_scalar('Loss', running_loss, epoch) # log training stats with torch.no_grad(): model.eval() for step, batch in enumerate(valid_dataloader): x_batch = batch[0].to(device) y_batch = batch[1].to(device) pred_batch = model(x_batch) loss = criterion(pred_batch, y_batch) tloss = loss.item() floss = calfloss(pred_batch, y_batch) ploss = floss.item() writer.add_scalar('valid loss', tloss) writer.add_scalar('valid fractional loss', ploss) print('Total valid loss is ' + str(ploss)) for step, batch in enumerate(test_dataloader): x_batch = batch[0].to(device) y_batch = batch[1].to(device) pred_batch = model(x_batch) loss = criterion(pred_batch, y_batch) tloss = loss.item() floss = calfloss(pred_batch, y_batch) ploss = floss.item() writer.add_scalar('test loss', tloss) writer.add_scalar('test fractional loss', ploss) print('Total test loss is ' + str(ploss)) writer.close()
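# The validation and test loops above call writer.add_scalar without a global_step, so successive
# epochs are all recorded at the default step. A small sketch of the same logging with an explicit
# step (the metric values here are placeholders):
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/val_logging_demo")
for epoch in range(1, 4):
    valid_loss = 1.0 / epoch                # placeholder metric
    writer.add_scalar('valid loss', valid_loss, global_step=epoch)
writer.close()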
def train_network(): print('') print('') # Start measuring time - to evaluate performance of the training function start = timeit.default_timer() # Set seeds set_seed(args) # Make folders if not yet exist try: os.makedirs('save') except FileExistsError: pass # Save relevant arguments from args and set hardcoded arguments lr = args.lr # learning rate batch_size = args.batch_size # Mini-batch size num_epochs = args.num_epochs # Number of epochs to train the network seq_len = args.seq_len # Network architecture: rnn_name = args.rnn_name inputs_list = args.inputs_list outputs_list = args.outputs_list load_rnn = args.load_rnn # If specified this is the name of pretrained RNN which should be loaded path_save = args.path_save # Create rnn instance and update lists of input, outputs and its name (if pretraind net loaded) net, rnn_name, inputs_list, outputs_list \ = create_rnn_instance(rnn_name, inputs_list, outputs_list, load_rnn, path_save, device) # Create log for this RNN and determine its full name rnn_full_name = create_log_file(rnn_name, inputs_list, outputs_list, path_save) ######################################################## # Create Dataset ######################################################## train_features, train_targets = load_data(args, args.train_file_name, inputs_list, outputs_list) dev_features, dev_targets = load_data(args, args.val_file_name, inputs_list, outputs_list) train_set = Dataset(train_features, train_targets, args) dev_set = Dataset(dev_features, dev_targets, args) print('Number of samples in training set: {}'.format( train_set.number_of_samples)) print('The training sets sizes are: {}'.format(train_set.df_lengths)) print('Number of samples in validation set: {}'.format( dev_set.number_of_samples)) print('') plot_results( net=net, args=args, dataset=dev_set, filepath='../../data/oval_easy_12_rounds.csv', seq_len=400, comment='This is the network at the beginning of the training', inputs_list=inputs_list, outputs_list=outputs_list, rnn_full_name=rnn_full_name) # Create PyTorch dataloaders for train and dev set train_generator = data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=args.num_workers) dev_generator = data.DataLoader(dataset=dev_set, batch_size=512, shuffle=False, num_workers=args.num_workers) # Print parameter count print_parameter_count(net) # Seems not to function well # Select Optimizer optimizer = optim.Adam(net.parameters(), amsgrad=True, lr=lr) # TODO: Verify if scheduler is working. 
Try tweaking parameters of below scheduler and try cyclic lr scheduler # scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=lr, max_lr=0.1) scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5) # Select Loss Function criterion = nn.MSELoss() # Mean square error loss function ''' Init Tensorboard ''' comment = f' batch_size={batch_size} lr={lr} seq_len={seq_len}' tb = SummaryWriter(comment=comment) ######################################################## # Training ######################################################## print("Starting training...") print('') time.sleep(0.001) # Create dictionary to store training history dict_history = {} dict_history['epoch'] = [] dict_history['time'] = [] dict_history['lr'] = [] dict_history['train_loss'] = [] dict_history['dev_loss'] = [] dict_history['dev_gain'] = [] dict_history['test_loss'] = [] dev_gain = 1 # The epoch_saved variable will indicate from which epoch is the last RNN model, # which was good enough to be saved epoch_saved = -1 for epoch in range(num_epochs): ########################################################################################################### # Training - Iterate batches ########################################################################################################### # Set RNN in training mode net = net.train() # Define variables accumulating training loss and counting training batchs train_loss = 0 train_batches = 0 # Iterate training over available batches # tqdm() is just a function which displays the progress bar # Otherwise the line below is the same as "for batch, labels in train_generator:" for batch, labels in tqdm(train_generator): # Iterate through batches # Reset the network (internal states of hidden layers and output history not the weights!) 
net.reset() # Further modifying the input and output form to fit RNN requirements # If GPU available we send tensors to GPU (cuda) if torch.cuda.is_available(): batch = batch.float().cuda().transpose(0, 1) labels = labels.float().cuda() else: batch = batch.float().transpose(0, 1) labels = labels.float() # # Reset memory of gradients # optimizer.zero_grad() # Warm-up (open loop prediction) to settle the internal state of RNN hidden layers net(rnn_input=batch[:args.warm_up_len, :, :]) # Reset memory of gradients optimizer.zero_grad() # Forward propagation - These are the results from which we calculate the update to RNN weights # GRU Input size must be (seq_len, batch, input_size) net(rnn_input=batch[args.warm_up_len:, :, :]) out = net.return_outputs_history() # Get loss loss = criterion(out[:, args.warm_up_len:, :], labels[:, args.warm_up_len:, :]) # Backward propagation loss.backward() # Gradient clipping - prevent gradient from exploding torch.nn.utils.clip_grad_norm_(net.parameters(), 100) # Update parameters optimizer.step() scheduler.step() # Update variables for loss calculation batch_loss = loss.detach() train_loss += batch_loss # Accumulate loss train_batches += 1 # Accumulate count so we can calculate mean later ########################################################################################################### # Validation - Iterate batches ########################################################################################################### # Set the network in evaluation mode net = net.eval() # Define variables accumulating evaluation loss and counting evaluation batches dev_loss = 0 dev_batches = 0 for (batch, labels) in tqdm(dev_generator): # Reset the network (internal states of hidden layers and output history not the weights!) net.reset() # Further modifying the input and output form to fit RNN requirements # If GPU available we send tensors to GPU (cuda) if torch.cuda.is_available(): batch = batch.float().cuda().transpose(0, 1) labels = labels.float().cuda() else: batch = batch.float().transpose(0, 1) labels = labels.float() # Warm-up (open loop prediction) to settle the internal state of RNN hidden layers net(rnn_input=batch) out = net.return_outputs_history() # Get loss # For evaluation we always calculate loss over the whole maximal prediction period # This allow us to compare RNN models from different epochs loss = criterion(out[:, args.warm_up_len:args.seq_len], labels[:, args.warm_up_len:args.seq_len]) # Update variables for loss calculation batch_loss = loss.detach() dev_loss += batch_loss # Accumulate loss dev_batches += 1 # Accumulate count so we can calculate mean later # Reset the network (internal states of hidden layers and output history not the weights!) net.reset() # Get current learning rate # TODO(Fixed. It does changes now): I think now the learning rate do not change during traing, or it is not a right way to get this info. for param_group in optimizer.param_groups: lr_curr = param_group['lr'] ''' Add data for tensorboard TODO : Add network graph and I/O to tensorboard ''' # tb.add_graph(net) tb.add_scalar('Train Loss', train_loss / train_batches, epoch) tb.add_scalar('Dev Loss', dev_loss / dev_batches, epoch) # Add the first sample of batch to tensorboard. Prediction is represented by Dotted line # TODO: Concatenate such graphs. 
But they are not continuous for i in range(labels.shape[2]): time_label = np.arange(0, labels.shape[1], 1) time_out = np.arange(0, out.shape[1], 1) true_data = labels[1, :, i] predicted_data = out[1, :, i] fig_tb = plt.figure(5) plt.plot(time_label, true_data.detach().cpu()) plt.plot(time_out, predicted_data.detach().cpu(), linestyle='dashed') tb.add_figure(tag=str(args.outputs_list[i]), figure=fig_tb, global_step=epoch) for name, param in net.named_parameters(): tb.add_histogram(name, param, epoch) tb.add_histogram(f'{name}.grad', param.grad, epoch) tb.close() # Write the summary information about the training for the just completed epoch to a dictionary dict_history['epoch'].append(epoch) dict_history['lr'].append(lr_curr) dict_history['train_loss'].append(train_loss.detach().cpu().numpy() / train_batches / (args.seq_len - args.warm_up_len)) dict_history['dev_loss'].append(dev_loss.detach().cpu().numpy() / dev_batches / (args.seq_len - args.warm_up_len)) # Get relative loss gain for network evaluation if epoch >= 1: dev_gain = (dict_history['dev_loss'][epoch - 1] - dict_history['dev_loss'][epoch]) / \ dict_history['dev_loss'][epoch - 1] dict_history['dev_gain'].append(dev_gain) # Print the summary information about the training for the just completed epoch print('\nEpoch: %3d of %3d | ' 'LR: %1.5f | ' 'Train-L: %6.4f | ' 'Val-L: %6.4f | ' 'Val-Gain: %3.2f |' % (dict_history['epoch'][epoch], num_epochs - 1, dict_history['lr'][epoch], dict_history['train_loss'][epoch], dict_history['dev_loss'][epoch], dict_history['dev_gain'][epoch] * 100)) print('') # Save the best model with the lowest dev loss # Always save the model from epoch 0 # TODO: this is a bug: the model from epoch 0 should only be saved if there is no pretrained network if epoch == 0: min_dev_loss = dev_loss # If the current loss is smaller than or equal to the minimal loss achieved so far, # save the current RNN model and record its loss as the new minimum if dev_loss <= min_dev_loss: epoch_saved = epoch min_dev_loss = dev_loss torch.save(net.state_dict(), args.path_save + rnn_full_name + '.pt', _use_new_zipfile_serialization=False) print('>>> saving best model from epoch {}'.format(epoch)) print('') else: print('>>> We keep model from epoch {}'.format(epoch_saved)) print('') plot_string = 'This is the network after {} training epochs'.format(epoch + 1) plot_results(net=net, args=args, dataset=dev_set, filepath='../../data/oval_easy_12_rounds.csv', seq_len=600, comment=plot_string, inputs_list=inputs_list, outputs_list=outputs_list, rnn_full_name=rnn_full_name) # Evaluate the performance of the current network # by checking its predictions on a randomly generated CartPole experiment # plot_results(net, args, val_file) # When training is finished, print the final message print("Training Completed... ") print(" ") # Calculate the total time it took to run the function stop = timeit.default_timer() total_time = stop - start # Return the total time it took to run the function return total_time
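# A minimal sketch for the open TODO above ("Add network graph and I/O to tensorboard").
# It assumes `net` can be traced with a single positional example input of shape
# (seq_len, batch, input_size); the wrapper above is called with rnn_input=..., so this may
# need adapting and is an illustration rather than the author's implementation.
def log_graph_to_tensorboard(tb, net, example_batch):
    import torch
    # add_graph traces the model once on the example input and writes the graph to TensorBoard
    with torch.no_grad():
        tb.add_graph(net, example_batch)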
def train(args, train_dataset, model, tokenizer): """ Train the model """ if args.local_rank in [-1, 0]: tensorboard_log_dir = os.path.join( "tensorboard", args.task_name, args.data_dir, "_".join([ args.model_name_or_path, str(args.max_seq_length), str( max(1, args.n_gpu) * args.gradient_accumulation_steps * args.per_gpu_train_batch_size), str(args.learning_rate), str(args.weight_decay), str(args.warmup_steps) ]), str(args.seed)) logger.info("Tensorboard dir: %s", tensorboard_log_dir) tb_writer = SummaryWriter(log_dir=tensorboard_log_dir) args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) train_sampler = RandomSampler( train_dataset) if args.local_rank == -1 else DistributedSampler( train_dataset) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) if args.max_steps > 0: t_total = args.max_steps args.num_train_epochs = args.max_steps // ( len(train_dataloader) // args.gradient_accumulation_steps) + 1 else: t_total = len( train_dataloader ) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": args.weight_decay, }, { "params": [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0 }, ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) # Check if saved optimizer or scheduler states exist if args.resume: opt_path = os.path.join(args.model_name_or_path, "optimizer.pt") sch_path = os.path.join(args.model_name_or_path, "scheduler.pt") if os.path.isfile(opt_path) and os.path.isfile(sch_path): # Load in optimizer and scheduler states optimizer.load_state_dict(torch.load(opt_path)) scheduler.load_state_dict(torch.load(sch_path)) else: raise RuntimeError( f"--resume was set but there are no optimizer and scheduler states at {opt_path} and {sch_path}" ) else: logger.info( "Not checking for optimizer and scheduler state as --resume was not set. Starting afresh" ) if args.fp16: try: from apex import amp except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) # multi-gpu training (should be after apex fp16 initialization) if args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True, ) # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if args.resume: if not args.global_step: raise ValueError( "--global_step (int) has to be set when using --resume") global_step = args.global_step epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) steps_trained_in_current_epoch = global_step % ( len(train_dataloader) // args.gradient_accumulation_steps) # logger.info( " Continuing training from checkpoint, will skip to saved global_step" ) logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", global_step) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) tr_loss, logging_loss = 0.0, 0.0 model.zero_grad() train_iterator = trange( epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0], ) set_seed(args) # Added here for reproductibility for _ in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) for step, batch in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue model.train() batch = tuple(t.to(args.device) for t in batch) inputs = { "input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3] } if args.model_type != "distilbert": inputs["token_type_ids"] = ( batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids outputs = model(**inputs) loss = outputs[ 0] # model outputs are always tuple in transformers (see doc) if args.n_gpu > 1: loss = loss.mean( ) # mean() to average on multi-gpu parallel training if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if args.local_rank in [ -1, 0 ] and args.logging_steps > 0 and global_step % args.logging_steps == 0: logs = {} if ( args.local_rank == -1 and args.evaluate_during_training ): # Only evaluate when single GPU otherwise metrics may not average well results = evaluate(args, model, tokenizer) for key, value in results.items(): logs[key] = value loss_scalar = (tr_loss - logging_loss) / args.logging_steps learning_rate_scalar = scheduler.get_lr()[0] logs["lr"] = learning_rate_scalar logs["train_loss"] = loss_scalar logging_loss = tr_loss logger.info( "Performance at global step: %s", str(global_step), ) for key, value in logs.items(): logger.info(" %s = %s", key, str(value)) tb_writer.add_scalar(key, value, global_step) if args.wandb: wandb_log({**logs, **{"step": global_step}}) if args.local_rank in [ -1, 0 ] and 
args.save_steps > 0 and global_step % args.save_steps == 0: # Save model checkpoint output_dir = os.path.join( args.output_dir, "checkpoint-{}".format(global_step)) if not os.path.exists(output_dir): os.makedirs(output_dir) model_to_save = ( model.module if hasattr(model, "module") else model ) # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) torch.save(args, os.path.join(output_dir, "training_args.bin")) logger.info("Saving model checkpoint to %s", output_dir) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) logger.info("Saving optimizer and scheduler states to %s", output_dir) if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break if args.max_steps > 0 and global_step > args.max_steps: train_iterator.close() break if args.local_rank in [-1, 0]: tb_writer.close() return global_step, tr_loss / global_step
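# Hedged illustration of the gradient-accumulation bookkeeping used in the loop above:
# parameters are only updated every `gradient_accumulation_steps` micro-batches, and each
# micro-batch loss is divided by that factor so the accumulated gradient matches one large
# batch. All names below are generic placeholders, not objects defined in this file.
def accumulate_and_step(model, optimizer, scheduler, batches, gradient_accumulation_steps, max_grad_norm):
    import torch
    model.zero_grad()
    for step, batch in enumerate(batches):
        loss = model(**batch)[0] / gradient_accumulation_steps  # scale the per-micro-batch loss
        loss.backward()                                         # gradients accumulate across micro-batches
        if (step + 1) % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()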
for i in range(1, motor_input.shape[0]): visual_output, motor_output, hidden, mu, logsig = rnn(visual_input[i].view(-1,1,64,64)*(1-pred_fb) + visual_output*pred_fb, motor_input[i].view(-1,16)*(1-pred_fb) + motor_output.view(-1,16)*pred_fb, hidden) motor_loss = criterion(motor_output.view(-1,16), motor_target[i].view(-1,16)) visual_loss = criterion(visual_output, visual_target[i].view(-1,1,64,64)) kl_loss = rnn.KL(mu, logsig) total_loss_ = visual_loss.item() + motor_loss.item() + kl_loss.item() visual_mult = total_loss_ / 3 / visual_loss.item() motor_mult = total_loss_ / 3 / motor_loss.item() Dkl_mult = total_loss_ / 3 / kl_loss.item() loss += visual_loss*visual_mult + motor_loss*motor_mult + kl_loss*Dkl_mult loss.backward() writer.add_scalar('motor loss', motor_loss, epoch) writer.add_scalar('visual loss', visual_loss, epoch) VLOSS.append(visual_loss.item()) MLOSS.append(motor_loss.item()) rnn.optimizer.step() printProgressBar(epoch + 1, EPOCHS, prefix='Epoch: {} vloss: {:.2f} mloss: {:.6f} Dkl: {:.3f}'.format(epoch, visual_loss.item(), motor_loss.item(), kl_loss.item()), suffix='Complete', length=25) if epoch % 1000 == 0: torch.save(rnn.state_dict(), 'checkpoint') except KeyboardInterrupt: print('\nKeyboard Interrupt') break
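# The multipliers above (visual_mult, motor_mult, Dkl_mult) rescale each term so that, at the
# current step, the visual, motor and KL losses each contribute total_loss_/3 to the combined
# objective. A compact, generic restatement of that balancing (an assumed reading of the code,
# not a separate API):
def balance_losses(losses):
    """losses: list of scalar tensors; returns their sum reweighted to equal contributions."""
    total = sum(l.item() for l in losses)
    n = len(losses)
    return sum(l * (total / n / l.item()) for l in losses)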
def train(self) -> bool: """Run training in a separate thread (added to the global application ThreadPool).""" # Free memory on the GPU self._clear_session() # Check that the data is set properly if len(self._train_image_names) == 0 or \ len(self._train_mask_names) == 0 or \ len(self._validation_image_names) == 0 or \ len(self._validation_mask_names) == 0: self._message = "No training/validation data found." return False if len(self._train_image_names) != len(self._train_mask_names): self._message = "The number of training images does not match the number of training masks." return False if len(self._validation_image_names) != len(self._validation_mask_names): self._message = "The number of validation images does not match the number of validation masks." return False # Define the transforms self._define_transforms() # Define the datasets and data loaders self._define_training_data_loaders() # Instantiate the model self._define_model() # Define the loss function self._define_training_loss() # Define the optimizer (with default parameters) self._define_optimizer() # Define the validation metric self._define_validation_metric() # Define experiment name and model name experiment_name, model_file_name = self._prepare_experiment_and_model_names() # Keep track of the best model file name self._best_model = model_file_name # Enter the main training loop best_metric = -1 best_metric_epoch = -1 epoch_loss_values = list() metric_values = list() # Initialize TensorBoard's SummaryWriter writer = SummaryWriter(experiment_name) for epoch in range(self._n_epochs): # Inform self._print_header(f"Epoch {epoch + 1}/{self._n_epochs}") # Switch to training mode self._model.train() epoch_loss = 0 step = 0 for batch_data in self._train_dataloader: # Update step step += 1 # Get the next batch and move it to device inputs, labels = batch_data[0].to(self._device), batch_data[1].to(self._device) # Zero the gradient buffers self._optimizer.zero_grad() # Forward pass outputs = self._model(inputs) # Calculate the loss loss = self._training_loss_function(outputs, labels) # Back-propagate loss.backward() # Update weights (optimize) self._optimizer.step() # Update and store metrics epoch_loss += loss.item() epoch_len = len(self._train_dataset) / self._train_dataloader.batch_size if epoch_len != int(epoch_len): epoch_len = int(epoch_len) + 1 print( f"Batch {step}/{epoch_len}: train_loss = {loss.item():.4f}", file=self._stdout) epoch_loss /= step epoch_loss_values.append(epoch_loss) print(f"Average loss = {epoch_loss:.4f}", file=self._stdout) writer.add_scalar("average_train_loss", epoch_loss, epoch + 1) # Validation if (epoch + 1) % self._validation_step == 0: self._print_header("Validation") # Switch to evaluation mode self._model.eval() # Make sure not to update the gradients with torch.no_grad(): # Global metrics metric_sum = 0.0 metric_count = 0 metric = 0.0 # Keep track of the metrics for all classes metric_sum_classes = self._out_channels * [0.0] metric_count_classes = self._out_channels * [0] metric_classes = self._out_channels * [0.0] for val_data in self._validation_dataloader: # Get the next batch and move it to device val_images, val_labels = val_data[0].to(self._device), val_data[1].to(self._device) # Apply sliding inference over ROI size val_outputs = sliding_window_inference( val_images, self._roi_size, self._sliding_window_batch_size, self._model) val_outputs = self._validation_post_transforms(val_outputs) # Compute overall metric value, not_nans = self._validation_metric( y_pred=val_outputs,
y=val_labels) not_nans = not_nans.item() metric_count += not_nans metric_sum += value.item() * not_nans # Compute metric for each class for c in range(self._out_channels): value_obj, not_nans = self._validation_metric( y_pred=val_outputs[:, c:c + 1], y=val_labels[:, c:c + 1]) not_nans = not_nans.item() metric_count_classes[c] += not_nans metric_sum_classes[c] += value_obj.item( ) * not_nans # Global metric metric = metric_sum / metric_count metric_values.append(metric) # Metric per class for c in range(self._out_channels): metric_classes[c] = metric_sum_classes[ c] / metric_count_classes[c] # Print summary print(f"Global metric = {metric:.4f} ", file=self._stdout) for c in range(self._out_channels): print( f"Class '{self._class_names[c]}' metric = {metric_classes[c]:.4f} ", file=self._stdout) # Do we have the best metric so far? if metric > best_metric: best_metric = metric best_metric_epoch = epoch + 1 torch.save(self._model.state_dict(), model_file_name) print( f"New best global metric = {best_metric:.4f} at epoch: {best_metric_epoch}", file=self._stdout) print( f"Saved best model '{Path(model_file_name).name}'", file=self._stdout) # Add validation loss and metrics to log writer.add_scalar("val_mean_dice_loss", metric, epoch + 1) for c in range(self._out_channels): metric_name = f"val_{self._class_names[c].lower()}_metric" writer.add_scalar(metric_name, metric_classes[c], epoch + 1) print( f"Training completed. Best_metric = {best_metric:.4f} at epoch: {best_metric_epoch}", file=self._stdout) writer.close() # Return success return True
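# The validation block above maintains a NaN-aware running mean: every batch metric returned by
# the MONAI metric is weighted by the number of valid (non-NaN) entries behind it (not_nans).
# A compact equivalent of that bookkeeping, with illustrative names:
def nan_aware_mean(batch_means, not_nans_counts):
    """batch_means[i] is a batch-mean metric; not_nans_counts[i] is its number of valid items."""
    total = sum(v * n for v, n in zip(batch_means, not_nans_counts))
    count = sum(not_nans_counts)
    return total / count if count > 0 else float("nan")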
if param.requires_grad: writer.add_histogram('Model/{}'.format(name), param, epoch) # Average the train / validation metrics train_loss = torch.mean(torch.tensor(epoch_train_loss)) train_acc = torch.mean(torch.tensor(epoch_train_acc)) val_loss = torch.mean(torch.tensor(epoch_val_loss)) val_acc = torch.mean(torch.tensor(epoch_val_acc)) if ((epoch + 1) % 100 == 0): print(f'\nEpoch {epoch} - Saving model...') model_path = os.path.join(logdir, 'model', 'state_dict.pth') opt_path = os.path.join(logdir, 'model', 'opt_state.pth') torch.save(net.state_dict(), model_path) torch.save(optimizer.state_dict(), opt_path) print('Done.') # --------------------------------------------------------------------------- # Log metrics writer.add_scalar('Train/Accuracy', train_acc, epoch) writer.add_scalar('Train/Loss', train_loss, epoch) writer.add_scalar('Validation/Accuracy', val_acc, epoch) writer.add_scalar('Validation/Loss', val_loss, epoch) writer.add_scalar('Learning Rate', scheduler.get_last_lr()[0], epoch) # --------------------------------------------------------------------------- print('\n[{:03d}/{:03d}] LOSS-------------- ACCURACY'.\ format(epoch, config['NUM_EPOCHS'])) print('[TRAIN] {} {} %'.format(train_loss, train_acc)) print('[VAL] {} {} %'.format(val_loss, val_acc)) # Adjust learning rate scheduler.step()
def train_classifier(model, data_loaders, args): """Train an emotion classifier.""" # Setup device = args.device optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) model, optimizer, _, start_epoch, is_trained = load_from_ckpnt( args.classifier_ckpnt, model, optimizer ) scheduler = MultiStepLR(optimizer, [3, 6, 9], gamma=0.3, last_epoch=start_epoch - 1) if is_trained: return model writer = SummaryWriter('runs/' + args.checkpoint.replace('.pt', '')) best_acc = -1 # Training loop for epoch in range(start_epoch, args.epochs): print("Epoch: %d/%d" % (epoch + 1, args.epochs)) kbar = pkbar.Kbar(target=len(data_loaders['train']), width=25) model.train() # model.enable_grads() for step, ex in enumerate(data_loaders['train']): images, _, emotions, _ = ex logits = model(images.to(device)) labels = emotions.to(device) loss = F.binary_cross_entropy_with_logits(logits, labels) kbar.update(step, [("loss", loss)]) optimizer.zero_grad() loss.backward() optimizer.step() writer.add_scalar( 'loss', loss.item(), epoch * len(data_loaders['train']) + step ) # NOTE: the break below stops each epoch after a single batch; it looks like leftover debugging code and should be removed for full training break writer.add_scalar( 'lr', optimizer.state_dict()['param_groups'][0]['lr'], epoch ) # Evaluation and model storing if epoch % 2 == 0: print("\nValidation") acc = eval_classifier(model, data_loaders['test'], args, writer, epoch=epoch) writer.add_scalar('mAP', acc, epoch) if acc >= best_acc: torch.save( { "epoch": epoch + 1, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict() }, args.classifier_ckpnt ) best_acc = acc else: # load checkpoint to update epoch checkpoint = torch.load(args.classifier_ckpnt) checkpoint["epoch"] += 1 torch.save(checkpoint, args.classifier_ckpnt) scheduler.step() # Test test_acc = eval_classifier(model, data_loaders['test'], args, writer) print(f"Test Accuracy: {test_acc}") return model
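# Sketch of how a checkpoint written with the dict format above can be restored later; the keys
# mirror those passed to torch.save(...) in train_classifier, while the function name and the
# rest are illustrative assumptions.
def restore_classifier(path, model, optimizer, device="cpu"):
    import torch
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]  # epoch to resume training from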
"Unbiased Accuracy : [f : {:.2f}, g : {:.2f}, v : {:.2f}, b : {:.2f}]" .format( 100 * f_acc_b / (len(biased_test_loader) * args.batch_size), 100 * g_acc_b / (len(biased_test_loader) * args.batch_size), 100 * v_acc_b / (len(biased_test_loader) * args.batch_size), 100 * b_acc_b / (len(biased_test_loader) * args.batch_size), 100 * f_acc_d / (len(unbiased_test_loader) * args.batch_size), 100 * g_acc_d / (len(unbiased_test_loader) * args.batch_size), 100 * v_acc_d / (len(unbiased_test_loader) * args.batch_size), 100 * b_acc_d / (len(unbiased_test_loader) * args.batch_size))) writer.add_scalar("loss/f", loss_f, epoch) writer.add_scalar("loss/g", loss_g, epoch) writer.add_scalar("loss/v", loss_v, epoch) writer.add_scalar("loss/b", loss_b, epoch) writer.add_scalar("loss/hsic", criterionHSIC(f_feats, g_feats), epoch) writer.add_scalar( "accuracy/biased/f", 100 * f_acc_b / (len(biased_test_loader) * args.batch_size), epoch) writer.add_scalar( "accuracy/biased/g", 100 * g_acc_b / (len(biased_test_loader) * args.batch_size), epoch) writer.add_scalar( "accuracy/biased/v",
class Trainer: def __init__(self, config): self.config = config self.config['trainer']['output_dir'] = os.path.join( str(pathlib2.Path(os.path.abspath(__name__)).parent), self.config['trainer']['output_dir']) self.data_cfg = self.config["data_cfg"] self.dataset_name = self.data_cfg['name'] self.method_name = "{0}_{1}".format(self.config['arch']['backbone'], self.dataset_name) self.save_dir = os.path.join(self.config['trainer']['output_dir'], self.method_name) self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') if self.config['trainer']['resume_checkpoint'] == '' and self.config['trainer']['finetune_checkpoint'] == '': shutil.rmtree(self.save_dir, ignore_errors=True) if not os.path.exists(self.checkpoint_dir): os.makedirs(self.checkpoint_dir) self.global_step = 0 self.start_epoch = 1 self.tensorboard_enable = self.config['trainer']['tensorboard'] self.epochs = self.config['trainer']['epochs'] self.save_interval = self.config['trainer']['save_interval'] self.show_images_interval = self.config['trainer']['show_images_interval'] self.display_interval = self.config['trainer']['display_interval'] if self.tensorboard_enable: from torch.utils.tensorboard import SummaryWriter self.writer = SummaryWriter(self.save_dir) # setup logger self.logger = setup_logger(os.path.join(self.save_dir, 'train_log')) self.logger.info(pformat(self.config)) # device torch.manual_seed(self.config['trainer']['seed']) # set the random seed for the CPU if len(self.config['trainer']['gpus']) > 0 and torch.cuda.is_available(): self.with_cuda = True torch.backends.cudnn.benchmark = True self.logger.info('Train with gpu {} & PyTorch {}'.format( self.config['trainer']['gpus'], torch.__version__)) self.gpus = { i: item for i, item in enumerate(self.config['trainer']['gpus']) } self.device = torch.device("cuda:0") torch.cuda.manual_seed(self.config['trainer']['seed']) # set the random seed for the current GPU torch.cuda.manual_seed_all(self.config['trainer']['seed']) # set the random seed for all GPUs else: self.with_cuda = False self.logger.info('Train with cpu & PyTorch {}'.format( torch.__version__)) self.device = torch.device("cpu") self.logger.info('Device {}'.format(self.device)) # train data loader self.logger.info('Loading train data...') self.train_data_len = len(os.listdir(self.data_cfg["train_img_path"])) self.train_set = CustomDataSetRBox( self.data_cfg, max_img_length=self.config["trainer"]["input_size"], long_size=self.config['trainer']['long_size']) self.embedding_size = self.train_set.embedding_size self.words_embeddings = self.train_set.words_embeddings self.train_loader = data.DataLoader( self.train_set, batch_size=self.config["trainer"]["batch_size"], shuffle=True, num_workers=self.config["trainer"]["num_workers"], drop_last=False, pin_memory=True) self.train_loader_len = len(self.train_loader) self.logger.info('Train data has {0} samples, {1} in loader'.format( self.train_data_len, self.train_loader_len)) # test data loader self.test_gt_path = self.train_set.test_gt_path self.test_img_files = self.train_set.test_img_files self.test_gt_files = self.train_set.test_gt_files self.test_words = self.train_set.test_words self.train_unique_words = self.train_set.train_unique_words self.label_encoder = LabelEncoder() # model self.logger.info('Loading model...') self.model = WordRetrievalModel( n_out=self.embedding_size, backbone=self.config["arch"]["backbone"], pre_trained=self.config["arch"]["pre_trained"]) # loss function self.logger.info('Loading loss function...') self.criterion = ModelLoss( weight_cls=self.config["loss"]["weight_cls"],
weight_angle=self.config["loss"]["weight_angle"], weight_diou=self.config["loss"]["weight_diou"], weight_embed=self.config["loss"]["weight_embed"]) # optimizer and lr_scheduler self.logger.info('Loading optimizer and lr_scheduler...') self.lr = self.config["optimizer"]['args']['lr'] self.lr_step = self.config["trainer"]["lr_step"] self.optimizer = self._initialize('optimizer', torch.optim, self.model.parameters()) self.scheduler = self._initialize('lr_scheduler', torch.optim.lr_scheduler, self.optimizer) if self.config['trainer']['resume_checkpoint'] != '': self._load_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True) elif self.config['trainer']['finetune_checkpoint'] != '': self._load_checkpoint( self.config['trainer']['finetune_checkpoint'], resume=False) # eval args self.cls_score_thresh = self.config['tester']['cls_score_thresh'] self.bbox_nms_overlap = self.config['tester']['bbox_nms_overlap'] self.query_nms_overlap = self.config['tester']['query_nms_overlap'] self.overlap_thresh = 0.25 self.metric = self.config['tester']['distance_metric'] # single machine, multiple GPUs num_gpus = torch.cuda.device_count() if num_gpus > 1: self.model = nn.DataParallel(self.model) self.model.to(self.device) self.metrics = { 'precision': 0, 'recall': 0, 'hmean': 0, 'map': 0, 'mr': 0, 'train_loss': float('inf'), 'best_model': '' } def train(self): """ Full training logic """ self.logger.info('Start training...') for epoch in range(self.start_epoch, self.epochs + 1): try: self.adjust_learning_rate(epoch) self.epoch_result = self._train_epoch(epoch) self._on_epoch_finish(epoch) except torch.cuda.CudaError: self._log_memory_usage() if self.tensorboard_enable: self.writer.close() self._on_train_finish() def _train_epoch(self, epoch): """ Training logic for an epoch """ self.model.train() epoch_start, batch_start = time.time(), time.time() train_loss = 0.0 lr = self.optimizer.param_groups[0]['lr'] for i, (img, gt_score, gt_geo, ignored_map, gt_embedding) in enumerate(self.train_loader): if i >= self.train_loader_len: break self.global_step += 1 lr = self.optimizer.param_groups[0]['lr'] cur_batch_size = img.size()[0] img, gt_score, gt_geo, ignored_map, gt_embedding = img.to(self.device), gt_score.to(self.device), \ gt_geo.to(self.device), ignored_map.to(self.device), \ gt_embedding.to(self.device) (predict_score, predict_geo), predict_embedding = self.model(img) loss_all, loss_cls, loss_ang, loss_diou, loss_embed = self.criterion( gt_score, predict_score, gt_geo, predict_geo, gt_embedding, predict_embedding, ignored_map) # backward self.optimizer.zero_grad() loss_all.backward() self.optimizer.step() loss_all = loss_all.item() loss_cls, loss_ang, loss_diou = loss_cls.item(), loss_ang.item(), loss_diou.item() loss_embed = loss_embed.item() train_loss += loss_all if i % self.display_interval == 0 or i == self.train_loader_len - 1: batch_time = time.time() - batch_start self.logger.info( '[{}/{}], [{}/{}], g_step: {}, Spe: {:.1f} sam/sec, l_all: {:.4f}, l_cls: {:.4f}, ' 'l_ang: {:.4f}, l_diou: {:.4f}, l_embed: {:.4f}, lr: {:.6}, T: {:.2f}' .format( str(epoch).zfill(3), self.epochs, str(i + 1).zfill(3), self.train_loader_len, self.global_step, self.display_interval * cur_batch_size / batch_time, loss_all, loss_cls, loss_ang, loss_diou, loss_embed, lr, batch_time)) batch_start = time.time() if self.tensorboard_enable: self.writer.add_scalar('TRAIN/LOSS/loss_all', loss_all, self.global_step) self.writer.add_scalar('TRAIN/LOSS/loss_cls', loss_cls, self.global_step) self.writer.add_scalar('TRAIN/LOSS/loss_ang', loss_ang,
self.global_step) self.writer.add_scalar('TRAIN/LOSS/loss_diou', loss_diou, self.global_step) self.writer.add_scalar('TRAIN/LOSS/loss_embed', loss_embed, self.global_step) self.writer.add_scalar('TRAIN/lr', lr, self.global_step) return { 'train_loss': train_loss / self.train_loader_len, 'lr': lr, 'time': time.time() - epoch_start, 'epoch': epoch } def _eval_map(self): self.logger.info('Enter evaluating...') self.model.eval() result_save_path = os.path.join(self.save_dir, 'result') if os.path.exists(result_save_path): shutil.rmtree(result_save_path, ignore_errors=True) if not os.path.exists(result_save_path): os.makedirs(result_save_path) predict_embeddings, joint_boxes, all_gt_boxes = [], [], [] qbs_words, qbs_queries, qbs_targets, db_targets, gt_targets = [], [], [], [], [] overlaps, used_test_word = [], [] # Compute a mapping from class string to class id... self.label_encoder.fit([word for word in self.test_words]) # Create queries... test_unique_words, counts = np.unique(self.test_words, return_counts=True) for idx, test_word in enumerate(self.test_words): gt_targets.extend(self.label_encoder.transform([test_word])) if test_word not in used_test_word and test_word in test_unique_words: qbs_words.append(test_word) qbs_queries.append(self.words_embeddings[test_word]) qbs_targets.extend(self.label_encoder.transform([test_word])) used_test_word.append(test_word) for i, (img_file, gt_file) in enumerate( zip(self.test_img_files, self.test_gt_files)): self.logger.info('Evaluating {} image: {}'.format(i, img_file)) # Get test gt boxes & gt words... gt_boxes, gt_words = [], [] with open(gt_file, mode='r', encoding='utf-8') as f: lines = f.readlines() for line in lines: line = line.strip().rstrip('\n').lstrip( '\ufeff').strip().split(',', maxsplit=8) gt_boxes.append([int(ver) for ver in line[:8]]) gt_words.append(str(line[-1]).strip().lower()) # Get img... im = Image.open(img_file) im = im.convert("RGB") im, ratio_w, ratio_h = resize_img( im, long_size=self.config['trainer']['long_size']) with torch.no_grad(): if str(self.device).__contains__('cuda'): torch.cuda.synchronize(self.device) transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) ]) im = transform(im).unsqueeze(0) im = im.to(self.device) (predict_score, predict_geo), predict_embed = self.model(im) if str(self.device).__contains__('cuda'): torch.cuda.synchronize(self.device) # Predicting boxes... 
predict_boxes, _ = get_boxes( score=predict_score.squeeze(0).cpu().numpy(), geo=predict_geo.squeeze(0).cpu().numpy(), cls_score_thresh=self.cls_score_thresh, bbox_nms_overlap=self.bbox_nms_overlap) predict_embed = predict_embed.squeeze(0).cpu().numpy() if predict_boxes is None: continue self.logger.info( 'Idx: {0} ===> Predict result [predict_boxes: {1}; gt_boxes: {2}]' .format(i, predict_boxes.shape, len(gt_boxes))) for predict_box in predict_boxes: min_x = min(predict_box[0], predict_box[2], predict_box[4], predict_box[6]) max_x = max(predict_box[0], predict_box[2], predict_box[4], predict_box[6]) min_y = min(predict_box[1], predict_box[3], predict_box[5], predict_box[7]) max_y = max(predict_box[1], predict_box[3], predict_box[5], predict_box[7]) w, h = max_x - min_x, max_y - min_y differ = h * 0.2 if h < w else w * 0.2 min_x, max_x = int((min_x + differ) / 4), int((max_x - differ) / 4) min_y, max_y = int((min_y + differ) / 4), int((max_y - differ) / 4) if min_x > max_x or min_y > max_y: continue predict_embeddings.append( np.mean(predict_embed[:, min_y:max_y, min_x:max_x], axis=(1, 2))) predict_boxes = adjust_ratio(predict_boxes, ratio_w, ratio_h) seq = [] if predict_boxes is not None: seq.extend([ ','.join([str(int(b)) for b in box[:-1]]) + '\n' for box in predict_boxes ]) with open( os.path.join( result_save_path, str(os.path.basename(img_file).split('.')[0]) + '.txt'), 'w') as f: f.writelines(seq) joint_boxes.extend(predict_boxes[:, :8]) all_gt_boxes.extend(gt_boxes) gt_boxes = np.array(gt_boxes) # Calculate overlap... overlap = cal_overlap(predict_boxes, gt_boxes) overlaps.append(overlap) inds = overlap.argmax(axis=1) db_targets.extend( self.label_encoder.transform([gt_words[idx] for idx in inds])) # End evaluate... db = np.vstack(predict_embeddings) if len(predict_embeddings) != 0 else np.array(predict_embeddings) all_overlaps = np.zeros((len(joint_boxes), len(all_gt_boxes)), dtype=np.float32) x, y = 0, 0 for o in overlaps: all_overlaps[y:y + o.shape[0], x:x + o.shape[1]] = o y += o.shape[0] x += o.shape[1] db_targets, qbs_targets, qbs_words = np.array(db_targets), np.array(qbs_targets), np.array(qbs_words) qbs_queries, joint_boxes = np.array(qbs_queries), np.array(joint_boxes) assert (qbs_queries.shape[0] == qbs_targets.shape[0]) assert (db.shape[0] == db_targets.shape[0]) self.logger.info('Calculate mAP...') mAP_qbs, mR_qbs = cal_map(qbs_queries, qbs_targets, db, db_targets, gt_targets, joint_boxes, all_overlaps, self.query_nms_overlap, self.overlap_thresh, qbs_words, num_workers=0) mAP_qbs, mR_qbs = np.mean(mAP_qbs * 100), np.mean(mR_qbs * 100) # Calculate recall precision f1 res_dict = cal_recall_precison_f1(gt_path=self.test_gt_path, result_path=result_save_path) return res_dict['recall'], res_dict['precision'], res_dict['hmean'], mAP_qbs, mR_qbs def _on_epoch_finish(self, epoch): # torch.cuda.empty_cache() self.logger.info( '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format( self.epoch_result['epoch'], self.epochs, self.epoch_result['train_loss'], self.epoch_result['time'], self.epoch_result['lr'])) if epoch % self.save_interval == 0: net_save_path = '{}/WordRetrievalNet_latest.pth'.format( self.checkpoint_dir) save_best = False if self.config['trainer']['metrics'] == 'map': # use mAP as the criterion for selecting the best model recall, precision, hmean, mAP_qbs, mR_qbs = self._eval_map() if self.tensorboard_enable: self.writer.add_scalar('EVAL/precision', precision, self.global_step) self.writer.add_scalar('EVAL/recall', recall, self.global_step) self.writer.add_scalar('EVAL/hmean', hmean,
self.global_step) self.writer.add_scalar('EVAL/mAP', mAP_qbs, self.global_step) self.writer.add_scalar('EVAL/mR', mR_qbs, self.global_step) self.logger.info( 'test: precision: {:.6f}, recall: {:.6f}, f1: {:.6f}, map: {:.2f}, mr: {:.2f}' .format(precision, recall, hmean, mAP_qbs, mR_qbs)) if mAP_qbs > self.metrics['map']: save_best = True self.metrics['train_loss'], self.metrics[ 'best_model'] = self.epoch_result[ 'train_loss'], net_save_path self.metrics['precision'], self.metrics[ 'recall'], self.metrics[ 'hmean'] = precision, recall, hmean self.metrics['map'], self.metrics['mr'] = mAP_qbs, mR_qbs else: if self.epoch_result['train_loss'] < self.metrics['train_loss']: save_best = True self.metrics['train_loss'], self.metrics[ 'best_model'] = self.epoch_result[ 'train_loss'], net_save_path self._save_checkpoint(self.epoch_result['epoch'], net_save_path, save_best) def _on_train_finish(self): for k, v in self.metrics.items(): self.logger.info('{}:{}'.format(k, v)) self.logger.info('Finish train.') def _log_memory_usage(self): if not self.with_cuda: return usage = [] for deviceID, device in self.gpus.items(): allocated = torch.cuda.memory_allocated( int(deviceID)) / (1024 * 1024) cached = torch.cuda.memory_cached(int(deviceID)) / (1024 * 1024) usage.append( ' CUDA: {0}; Allocated: {1} MB; Cached: {2} MB \n'.format( device, allocated, cached)) self.logger.debug("Memory Usage: \n{}".format(''.join(usage))) def _save_checkpoint(self, epoch, file_name, save_best=False): """ Saving checkpoints """ state_dict = { 'epoch': epoch, 'global_step': self.global_step, 'state_dict': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'scheduler': self.scheduler.state_dict(), 'config': self.config, 'metrics': self.metrics, } filename = os.path.join(self.checkpoint_dir, file_name) torch.save(state_dict, filename) if save_best: shutil.copy( filename, os.path.join(self.checkpoint_dir, 'WordRetrievalNet_best.pth')) self.logger.info("Saving current best: {}".format(file_name)) else: self.logger.info("Saving checkpoint: {}".format(filename)) def _load_checkpoint(self, checkpoint_path, resume): """ Resume from saved checkpoints """ self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path) self.model.load_state_dict({ k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items() }) if resume: self.global_step = checkpoint['global_step'] self.start_epoch = checkpoint['epoch'] + 1 self.config['lr_scheduler']['args'][ 'last_epoch'] = self.start_epoch self.optimizer.load_state_dict(checkpoint['optimizer']) self.scheduler.load_state_dict(checkpoint['scheduler']) if 'metrics' in checkpoint: self.metrics = checkpoint['metrics'] if self.with_cuda: for state in self.optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to(self.device) self.logger.info("Resume from checkpoint {} (epoch {})".format( checkpoint_path, self.start_epoch)) else: self.logger.info( "FineTune from checkpoint {}".format(checkpoint_path)) def _initialize(self, name, module, *args, **kwargs): module_name = self.config[name]['type'] module_args = self.config[name]['args'] assert all([ k not in module_args for k in kwargs ]), 'Overwriting kwargs given in config file is not allowed' module_args.update(kwargs) return getattr(module, module_name)(*args, **module_args) def adjust_learning_rate(self, epoch): if epoch in self.lr_step: self.lr = self.lr * 0.1 for param_group in self.optimizer.param_groups: param_group['lr'] = self.lr
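# Example of the config-driven construction performed by _initialize above: the 'type' field
# names a class inside the module passed in, and 'args' are forwarded as keyword arguments.
# The values below are illustrative only, not taken from a real config file.
example_config = {
    "optimizer": {"type": "Adam", "args": {"lr": 1e-3, "weight_decay": 5e-4}},
    "lr_scheduler": {"type": "StepLR", "args": {"step_size": 50, "gamma": 0.1}},
}
# _initialize('optimizer', torch.optim, model.parameters()) then resolves to:
#     torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
# and _initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer) resolves to:
#     torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)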
def train(opt): ################################ # Build dataloader ################################ loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length ########################## # Initialize infos ########################## infos = { 'iter': 0, 'epoch': 0, 'loader_state_dict': None, 'vocab': loader.get_vocab(), } # Load old infos(if there is) and check if models are compatible if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert getattr(saved_model_opt, checkme) == getattr( opt, checkme ), "Command line argument and saved model disagree on '%s' " % checkme infos['opt'] = opt ######################### # Build logger ######################### # naive dict logger histories = defaultdict(dict) if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: histories.update(utils.pickle_load(f)) # tensorboard logger tb_summary_writer = SummaryWriter(opt.checkpoint_path) ########################## # Build model ########################## opt.vocab = loader.get_vocab() model = models.setup(opt).cuda() del opt.vocab # Load pretrained weights: if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'model.pth')): model.load_state_dict( torch.load(os.path.join(opt.start_from, 'model.pth'))) # Wrap generation model with loss function(used for training) # This allows loss function computed separately on each machine lw_model = LossWrapper(model, opt) # Wrap with dataparallel dp_model = torch.nn.DataParallel(model) dp_lw_model = torch.nn.DataParallel(lw_model) ########################## # Build optimizer ########################## if opt.noamopt: assert opt.caption_model in [ 'transformer', 'bert', 'm2transformer' ], 'noamopt can only work with transformer' optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(model.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) ######################### # Get ready to start ######################### iteration = infos['iter'] epoch = infos['epoch'] # For back compatibility if 'iterators' in infos: infos['loader_state_dict'] = { split: { 'index_list': infos['split_ix'][split], 'iter_counter': infos['iterators'][split] } for split in ['train', 'val', 'test'] } loader.load_state_dict(infos['loader_state_dict']) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) if opt.noamopt: optimizer._step = iteration # flag indicating finish of an epoch # Always set to True at the beginning to initialize the lr or etc. 
epoch_done = True # Assure in training mode dp_lw_model.train() # Start training try: while True: # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break if epoch_done: if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min( opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False # If start structure loss training if opt.structure_after != -1 and epoch >= opt.structure_after: struc_flag = True init_scorer(opt.cached_tokens) else: struc_flag = False epoch_done = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag) loss = model_out['loss'].mean() loss.backward() if opt.grad_clip_value != 0: getattr(torch.nn.utils, 'clip_grad_%s_' % (opt.grad_clip_mode))(model.parameters(), opt.grad_clip_value) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if struc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start)) elif not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, model_out['reward'].mean(), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): tb_summary_writer.add_scalar('train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr tb_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration) tb_summary_writer.add_scalar('scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration) elif struc_flag: tb_summary_writer.add_scalar( 'lm_loss', model_out['lm_loss'].mean().item(), iteration) tb_summary_writer.add_scalar( 'struc_loss', model_out['struc_loss'].mean().item(), iteration) tb_summary_writer.add_scalar( 'reward', 
model_out['reward'].mean().item(), iteration) tb_summary_writer.add_scalar( 'reward_var', model_out['reward'].var(1).mean(), iteration) histories['loss_history'][ iteration] = train_loss if not sc_flag else model_out[ 'reward'].mean() histories['lr_history'][iteration] = opt.current_lr histories['ss_prob_history'][iteration] = model.ss_prob # update infos infos['iter'] = iteration infos['epoch'] = epoch infos['loader_state_dict'] = loader.state_dict() # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch) or \ (epoch_done and opt.save_every_epoch): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, lw_model.crit, loader, eval_kwargs) if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary tb_summary_writer.add_scalar('validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): tb_summary_writer.add_scalar(k, v, iteration) histories['val_result_history'][iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True # Dump miscalleous informations infos['best_val_score'] = best_val_score utils.save_checkpoint(opt, model, infos, optimizer, histories) if opt.save_history_ckpt: utils.save_checkpoint( opt, model, infos, optimizer, append=str(epoch) if opt.save_every_epoch else str(iteration)) if best_flag: utils.save_checkpoint(opt, model, infos, optimizer, append='best') except (RuntimeError, KeyboardInterrupt): print('Save ckpt on exception ...') utils.save_checkpoint(opt, model, infos, optimizer) print('Save ckpt done.') stack_trace = traceback.format_exc() print(stack_trace)
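# The stepwise learning-rate decay applied when epoch_done above, factored into a standalone
# helper for clarity: after `decay_start` the base rate is multiplied by `decay_rate` once per
# `decay_every` epochs (integer division, matching the `frac` computation in the loop).
def decayed_lr(base_lr, epoch, decay_start, decay_every, decay_rate):
    if decay_start < 0 or epoch <= decay_start:
        return base_lr
    frac = (epoch - decay_start) // decay_every
    return base_lr * (decay_rate ** frac)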
def train_mine_policy(scenario: Scenario, horizon: int, batch_size: int, epochs: int, ntrvs: int, mine_class: nn.Module, mine_params, q_net: nn.Module, pi_net: nn.Module, tradeoff: float, lr: float, tag: str = None, save_every: int = 100, log_video_every: Union[int, None] = None, minibatch_size=0, opt_iters=1, lowest_mi=np.inf, cutoff=np.inf, device=pt.device('cpu')): q_net.to(device=device) pi_net.to(device=device) opt = pt.optim.Adam(list(pi_net.parameters()) + list(q_net.parameters()), lr=lr) mine = [mine_class().to(device=device) for t in range(horizon)] last_time = time.time() mi = pt.zeros(horizon).to(device=device) scenario.device = pt.device('cpu') prev_best_value = np.inf current_value = np.inf if minibatch_size == 0: minibatch_size = batch_size if tag is not None: writer = SummaryWriter(f'runs/{tag}', flush_secs=1) for epoch in range(epochs): #if epoch % save_every == 0 or epoch == epochs - 1: start_epoch_event = pt.cuda.Event(enable_timing=True) end_epoch_event = pt.cuda.Event(enable_timing=True) end_rollout_event = pt.cuda.Event(enable_timing=True) start_epoch_event.record() pi_log_probs = pt.zeros((horizon, minibatch_size), device=device) q_log_probs = pt.zeros((horizon, minibatch_size), device=device) q_net.cpu() pi_net.cpu() states, outputs, samples, trvs, inputs, costs = rollout( pi_net, q_net, ntrvs, scenario, horizon, batch_size, pt.device('cpu')) end_rollout_event.record() pt.cuda.synchronize() elapsed_rollout_time = start_epoch_event.elapsed_time( end_rollout_event) / 1000 print(f'Rollout Time: {elapsed_rollout_time:.3f}') print( f'Mean Abs. Displacement: {pt.abs(states[0, -1, :] - states[1, -1, :]).mean().detach().item()}' ) states = states.to(device) outputs = outputs.to(device) samples = samples.to(device) trvs = trvs.to(device) inputs = inputs.to(device) costs = costs.to(device) q_net.to(device) pi_net.to(device) for s in range(batch_size): trv = pt.zeros(ntrvs, device=device) for t in range(horizon): trvs[:, t, s] = q_net(outputs[:, t, s], trv, t, samples[:, t, s])[0] trv = trvs[:, t, s] value = costs.sum(axis=0).mean().item() if tradeoff > -1: states_mi = states.detach().cuda() trvs_mi = trvs.detach().cuda() for t in range(horizon): mine[t].cuda() if epoch == 0: values = train_mine_network( mine[t], (states_mi[:, t, :], trvs_mi[:, t, :]), epochs=100 * mine_params['epochs']) else: train_mine_network(mine[t], (states_mi[:, t, :], trvs_mi[:, t, :]), epochs=mine_params['epochs']) for t in range(horizon): num_datapts = states.shape[2] batch_size = num_datapts joint_batch_idx = np.random.choice(range(num_datapts), size=num_datapts, replace=False) marginal_batch_idx1 = np.random.choice(range(num_datapts), size=num_datapts, replace=False) marginal_batch_idx2 = np.random.choice(range(num_datapts), size=num_datapts, replace=False) joint_batch = pt.cat( (states[:, t, joint_batch_idx], trvs[:, t, joint_batch_idx]), axis=0).t() marginal_batch = pt.cat((states[:, t, marginal_batch_idx1], trvs[:, t, marginal_batch_idx2]), axis=0).t() j_T = mine[t](joint_batch) m_T = mine[t](marginal_batch) mi[t] = j_T.mean() - pt.log(pt.mean(pt.exp(m_T))) mi_sum = mi.sum() baseline = costs.sum(axis=0).mean() current_value = value + tradeoff * mi_sum.detach() if value < cutoff and mi_sum < lowest_mi: print('Saving Model...') lowest_mi = mi_sum.item() pt.save( { 'pi_net_state_dict': pi_net.state_dict(), 'q_net_state_dict': q_net.state_dict() }, f'models/{tag}_epoch_{epoch}_mi_{lowest_mi:.3f}') else: print(f'Current Best: {prev_best_value}') for iter in range(opt_iters): print(f'Computing Iteration 
{iter}') minibatch_idx = np.random.choice(range(batch_size), size=minibatch_size, replace=False) outputs_minibatch = outputs[:, :, minibatch_idx] trvs_minibatch = trvs[:, :, minibatch_idx] inputs_minibatch = inputs[:, :, minibatch_idx] costs_minibatch = costs[:, minibatch_idx] for s in range(minibatch_size): trv = pt.zeros(ntrvs, device=device) for t in range(horizon): q_log_probs[t, s] = q_net.log_prob(trvs[:, t, s].detach(), outputs_minibatch[:, t, s], trv.detach(), t) pi_log_probs[t, s] = pi_net.log_prob( inputs_minibatch[:, t, s].detach(), trvs_minibatch[:, t, s].detach(), t) trv = trvs_minibatch[:, t, s] opt.zero_grad() loss = pt.mul(pi_log_probs.sum(axis=0), costs_minibatch.sum(axis=0) - baseline).mean() + \ pt.mul(q_log_probs.sum(axis=0), costs_minibatch.sum(axis=0) - baseline).mean() + \ tradeoff * mi_sum loss.backward() opt.step() pi_log_probs = pi_log_probs.detach() q_log_probs = q_log_probs.detach() if tag is not None: writer.add_scalar('Loss/Total', value + tradeoff * mi.sum().item(), epoch) writer.add_scalar('Loss/MI', mi_sum, epoch) writer.add_scalar('Loss/Cost', value, epoch) writer.add_histogram('Loss/Cost Dist', costs.sum(axis=0), epoch) if log_video_every is not None and epoch % log_video_every == 0: print('Saving Video...') best_traj_idx = pt.argmin(costs.sum(axis=0)) worst_traj_idx = pt.argmax(costs.sum(axis=0)) best_traj_vid = pt.stack([ pt.stack([ outputs[:, t, best_traj_idx].view(3, 64, 64) for t in range(horizon) ]) ]) worst_traj_vid = pt.stack([ pt.stack([ outputs[:, t, worst_traj_idx].view(3, 64, 64) for t in range(horizon) ]) ]) writer.add_video('Loss/Worst Traj', worst_traj_vid, epoch) writer.add_video('Loss/Best Traj', best_traj_vid, epoch) mi = mi.detach() end_epoch_event.record() pt.cuda.synchronize() elapsed_epoch_time = start_epoch_event.elapsed_time(end_epoch_event) / 1000 print( f'[{tradeoff}.{epoch}: {elapsed_epoch_time:.3f}]\t\tAvg. Cost: {value:.3f}\t\tEst. MI: {mi_sum.item():.5f}\t\tTotal: {value + tradeoff * mi_sum.item():.3f}\t\t Lowest MI: {lowest_mi:.3f}' ) if epoch == epochs - 1: return lowest_mi
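# The per-timestep MI estimate above (mi[t] = j_T.mean() - log(mean(exp(m_T)))) is the
# Donsker-Varadhan lower bound used by MINE. A compact restatement with generic tensors;
# `statistics_net` stands for the per-timestep mine[t] network:
import torch as pt

def mine_lower_bound(statistics_net, joint_samples, marginal_samples):
    j_T = statistics_net(joint_samples)     # T(x, z) on samples from the joint distribution
    m_T = statistics_net(marginal_samples)  # T(x, z) on samples from the product of marginals
    return j_T.mean() - pt.log(pt.exp(m_T).mean())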