def run(n_epoch):
    # train_set = StateTransitionsDataset(hdf5_file="c_swm/data/blocks-4-4-det_train.h5", n_obj=9)
    # test_set = StateTransitionsDataset(hdf5_file="c_swm/data/blocks-4-4-det_eval.h5", n_obj=9)
    # print("Training Examples: {}, Testing Examples: {}".format(len(train_set), len(test_set)))
    # train_set = StateTransitionsDataset(hdf5_file="c_swm/data/{}_all.h5".format(PREFIX), n_obj=OBJS+STACKS, remove_bg=REMOVE_BG)
    train_set = Concat([
        StateTransitionsDataset(
            hdf5_file="c_swm/data/blocks-{}-{}-{}_all.h5".format(OBJS, STACKS, 0),
            n_obj=OBJS + STACKS, max_n_obj=MAX_N)
        for OBJS in TRAIN_DATASETS_OBJS
    ])
    print("Training Examples: {}".format(len(train_set)))
    assert len(train_set) % TRAIN_BZ == 0
    # assert len(test_set) % TEST_BZ == 0
    train_loader = DataLoader(train_set, batch_size=TRAIN_BZ, shuffle=True, num_workers=4)
    # test_loader = DataLoader(test_set, batch_size=TEST_BZ, shuffle=True)
    vae = FoSae().to(device)
    optimizer = Adam(vae.parameters(), lr=1e-3)
    # optimizer = SGD(vae.parameters(), lr=1e-3)
    scheduler = LambdaLR(optimizer, lambda e: 1 if e < 100 else 0.1)
    best_loss = float('inf')
    best_tmp = 0
    for e in range(n_epoch):
        temp = np.maximum(TEMP_BEGIN * np.exp(-ANNEAL_RATE * e), TEMP_MIN)
        print("Epoch: {}, Temperature: {}, Lr: {}".format(e, temp, scheduler.get_last_lr()))
        sys.stdout.flush()
        train_loss = epoch_routine(train_loader, vae, temp, optimizer)
        print('====> Epoch: {} Average train loss: {:.4f}'.format(e, train_loss))
        # no held-out set is loaded here, so the "test" pass reuses train_loader without an optimizer
        test_loss = epoch_routine(train_loader, vae, temp)
        print('====> Epoch: {} Average test loss: {:.4f}, Best Test loss: {:.4f}, Best temp {:.4f}'
              .format(e, test_loss, best_loss, best_tmp))
        if test_loss < best_loss:
            print("Save Model")
            torch.save(vae.state_dict(), "fosae/model_{}/{}.pth".format(PREFIX, MODEL_NAME))
            best_loss = test_loss
            best_tmp = temp
        scheduler.step()
def run(n_epoch):
    train_set, test_set, _, _ = get_train_and_test_dataset(*load_data())
    print("Training Examples: {}, Testing Examples: {}".format(len(train_set), len(test_set)))
    assert len(train_set) % TRAIN_BZ == 0
    assert len(test_set) % TEST_BZ == 0
    train_loader = DataLoader(train_set, batch_size=TRAIN_BZ, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=TEST_BZ, shuffle=False)
    vae = CubeSae(BACK_TO_LOGIT).to(device)
    # load_model(vae)
    optimizer = Adam(vae.parameters(), lr=1e-3)
    scheduler = LambdaLR(optimizer, lambda e: 1.0 if e < 300 else 0.1)
    best_loss = float('inf')
    best_epoch = 0
    all_train_loss = []
    all_validation_loss = []
    for e in range(n_epoch):
        sys.stdout.flush()
        temp1 = np.maximum(TEMP_BEGIN_SAE * np.exp(-ANNEAL_RATE_SAE * e), TEMP_MIN_SAE)
        temp2 = np.maximum(TEMP_BEGIN_AAE * np.exp(-ANNEAL_RATE_AAE * e), TEMP_MIN_AAE)
        print("\n" + "-" * 50)
        print("Epoch: {}, Temperature: {:.2f} {:.2f}, Lr: {}".format(e, temp1, temp2, scheduler.get_last_lr()))
        train_loss = train(train_loader, vae, optimizer, (temp1, temp2, True))
        validation_loss = test(test_loader, vae, e, (temp1, temp2, False))
        all_train_loss.append(train_loss)
        all_validation_loss.append(validation_loss)
        print("\nBest test loss {:.5f} in epoch {}".format(best_loss, best_epoch))
        if validation_loss < best_loss:
            print("Save model to {}".format(MODEL_PATH))
            torch.save(vae.state_dict(), MODEL_PATH)
            best_loss = validation_loss
            best_epoch = e
        scheduler.step()
    plot_loss(all_train_loss, all_validation_loss, n_epoch, PLOT_DIR)
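# A minimal, self-contained sketch of the schedule pattern shared by the two run()
# functions above: a sampling temperature (e.g. for Gumbel-softmax) that decays
# exponentially toward a floor each epoch, paired with a LambdaLR that keeps the
# base lr for the first K epochs and then drops it by 10x. The constants below
# (TEMP_BEGIN, ANNEAL_RATE, TEMP_MIN, the 100-epoch cutoff) are illustrative
# assumptions, not values taken from the projects above.
import numpy as np
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

TEMP_BEGIN, TEMP_MIN, ANNEAL_RATE = 5.0, 0.5, 0.03

optimizer = Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler = LambdaLR(optimizer, lambda e: 1.0 if e < 100 else 0.1)

for epoch in range(150):
    temp = np.maximum(TEMP_BEGIN * np.exp(-ANNEAL_RATE * epoch), TEMP_MIN)
    # ... one epoch of training with `temp` would go here ...
    if epoch % 50 == 0:
        print(epoch, round(float(temp), 3), scheduler.get_last_lr())
    scheduler.step()  # stepped once per epoch, matching the functions above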
def train( run_name: str, # Data train_filepath: str = CSNJS_TRAIN_FILEPATH, eval_filepath: str = CSNJS_VALID_FILEPATH, spm_filepath: str = SPM_UNIGRAM_FILEPATH, program_mode="identity", eval_program_mode="identity", label_mode="identifier", num_workers=1, limit_dataset_size=-1, # Model model_type="transformer", n_decoder_layers=4, d_model: int = 512, resume_path: str = "", resume_encoder_name: str = "encoder_q", # encoder_q, encoder_k, encoder resume_project: bool = False, # Optimization train_decoder_only: bool = False, num_epochs: int = 50, save_every: int = 2, batch_size: int = 256, lr: float = 8e-4, adam_beta1: float = 0.9, adam_beta2: float = 0.98, use_lr_warmup: bool = True, loss_type = "nll_token", # nll_token or nll_sequence # Loss subword_regularization_alpha: float = 0, # Computational use_cuda: bool = True, auto_test: bool = True, seed: int = 0, ): """Train model""" torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) run_dir = RUN_DIR / run_name run_dir.mkdir(exist_ok=True, parents=True) logger.add(str((run_dir / "train.log").resolve())) logger.info(f"Saving logs, model checkpoints to {run_dir}") config = locals() logger.info(f"Config: {config}") wandb.init(name=run_name, config=config, job_type="training", project="identifier-prediction", entity="ml4code") if use_cuda: assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False" train_augmentations = [ {"fn": "sample_lines", "line_length_pct": 0.5}, {"fn": "insert_var_declaration", "prob": 0.5}, {"fn": "rename_variable", "prob": 0.5}, ] sp = spm.SentencePieceProcessor() sp.Load(spm_filepath) pad_id = sp.PieceToId("[PAD]") # Create training dataset and dataloader logger.info(f"Training data path {train_filepath}") train_dataset = get_csnjs_dataset(train_filepath, label_mode=label_mode, limit_size=limit_dataset_size) logger.info(f"Training dataset size: {len(train_dataset)}") train_loader = javascript_dataloader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, augmentations=train_augmentations, sp=sp, program_mode=program_mode, subword_regularization_alpha=subword_regularization_alpha, ) # Create eval dataset and dataloader logger.info(f"Eval data path {eval_filepath}") eval_dataset = get_csnjs_dataset(eval_filepath, label_mode=label_mode, limit_size=limit_dataset_size) logger.info(f"Eval dataset size: {len(eval_dataset)}") eval_loader = javascript_dataloader( eval_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, augmentations=[], sp=sp, program_mode=eval_program_mode, subword_regularization_alpha=subword_regularization_alpha, ) # Create model pad_id = sp.PieceToId("[PAD]") if model_type == "transformer": model = TransformerModel(n_tokens=sp.GetPieceSize(), pad_id=pad_id, n_decoder_layers=n_decoder_layers, d_model=d_model) logger.info(f"Created TransformerModel with {count_parameters(model)} params") elif model_type == "lstm": model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(), pad_id=pad_id, d_model=d_model) logger.info(f"Created Seq2SeqLSTM with {count_parameters(model)} params") # Load checkpoint if resume_path: logger.info(f"Resuming training from checkpoint {resume_path}, resume_encoder_name={resume_encoder_name}") checkpoint = torch.load(resume_path) pretrained_state_dict = checkpoint["model_state_dict"] encoder_state_dict = {} assert resume_encoder_name in ["encoder_k", "encoder_q", "encoder"] for key, value in pretrained_state_dict.items(): if key.startswith(resume_encoder_name + ".") and "project_layer" not in key: 
remapped_key = key[len(resume_encoder_name + ".") :] logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}") encoder_state_dict[remapped_key] = value if key.startswith(resume_encoder_name + ".") and "project_layer.0." in key and resume_project: remapped_key = key[len(resume_encoder_name + ".") :] logger.debug(f"Remapping checkpoint project key {key} to {remapped_key}. Value mean: {value.mean().item()}") encoder_state_dict[remapped_key] = value model.encoder.load_state_dict(encoder_state_dict, strict=False) logger.info(f"Loaded state dict from {resume_path}") logger.info(f"Loaded keys: {encoder_state_dict.keys()}") # Set up optimizer model = nn.DataParallel(model) model = model.cuda() if use_cuda else model wandb.watch(model, log="all") params = model.module.decoder.parameters() if train_decoder_only else model.parameters() optimizer = torch.optim.Adam(params, lr=lr, betas=(adam_beta1, adam_beta2), eps=1e-9) if use_lr_warmup: scheduler = get_linear_schedule_with_warmup(optimizer, 5000, len(train_loader) * num_epochs) else: scheduler = LambdaLR(optimizer, lr_lambda=lambda x: 1.0) global_step = 0 min_eval_loss = float("inf") for epoch in tqdm.trange(1, num_epochs + 1, desc="training", unit="epoch", leave=False): logger.info(f"Starting epoch {epoch}\n") if train_decoder_only: model.module.encoder.eval() model.module.decoder.train() else: model.train() pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}") for X, Y, X_lengths, Y_lengths in pbar: if use_cuda: X = X.cuda() Y = Y.cuda() X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda() optimizer.zero_grad() # NOTE: X and Y are [B, max_seq_len] tensors (batch first) logits = model(X, Y[:, :-1], X_lengths, Y_lengths) if loss_type == "nll_sequence": loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction='sum') loss = loss / X.size(0) # Average over num sequences, not target sequence lengths # Thus, minimize bits per sequence. 
elif loss_type == "nll_token": loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id,) loss.backward() optimizer.step() scheduler.step() # Log loss global_step += 1 wandb.log( {"epoch": epoch, f"label-{label_mode}/train_loss": loss.item(), "lr": scheduler.get_last_lr()[0]}, step=global_step ) pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}") # Evaluate logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...") max_decode_len = 20 if label_mode == "identifier" else 200 eval_loss = _evaluate(model, eval_loader, sp, use_cuda=use_cuda, max_decode_len=max_decode_len, loss_type=loss_type) logger.info(f"Evaluation loss after epoch {epoch} ({global_step} steps): {eval_loss:.4f}") wandb.log({"epoch": epoch, f"label-{label_mode}/eval_loss": eval_loss}, step=global_step) # Save checkpoint if save_every and epoch % save_every == 0 or eval_loss < min_eval_loss: checkpoint = { "model_state_dict": model.module.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch, "global_step": global_step, "config": config, "eval_loss": eval_loss, } if eval_loss < min_eval_loss: logger.info(f"New best evaluation loss: prev {min_eval_loss:.4f} > new {eval_loss:.4f}") min_eval_loss = eval_loss model_file = run_dir / "ckpt_best.pth" else: model_file = run_dir / f"ckpt_ep{epoch:04d}.pth" logger.info(f"Saving checkpoint to {model_file}...") torch.save(checkpoint, str(model_file.resolve())) wandb.save(str(model_file.resolve())) logger.info("Done.") if auto_test: best_ckpt = run_dir / "ckpt_best.pth" test( str(best_ckpt.resolve()), CSNJS_TEST_FILEPATH, spm_filepath, program_mode, label_mode, num_workers, -1, n_decoder_layers=n_decoder_layers, )
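# A small standalone sketch of the two loss_type options supported by the train()
# function above: "nll_token" averages cross-entropy over non-pad target tokens,
# while "nll_sequence" sums over tokens and divides by the number of sequences, so
# longer sequences contribute proportionally more. The shapes and pad_id below are
# assumptions chosen only to keep the example self-contained.
import torch
import torch.nn.functional as F

pad_id = 0
B, T, V = 4, 7, 11                      # batch, target length, vocab size
logits = torch.randn(B, T, V)           # batch-first decoder outputs
targets = torch.randint(1, V, (B, T))   # shifted targets, no pads in this toy case

nll_token = F.cross_entropy(logits.transpose(1, 2), targets, ignore_index=pad_id)
nll_sequence = F.cross_entropy(
    logits.transpose(1, 2), targets, ignore_index=pad_id, reduction="sum"
) / B  # average over sequences, not over tokens

print(nll_token.item(), nll_sequence.item())  # nll_sequence == nll_token * T when nothing is padded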
def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform( args.train_resizing, random_horizontal_flip=not args.no_hflip, random_color_jitter=False, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrain=not args.scratch) pool_layer = nn.Identity() if args.no_pool else None classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = LambdaLR( optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x))**(-args.lr_decay)) # define loss function domain_adv = DomainAdversarialLoss(domain_discri).to(device) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0. for epoch in range(args.epochs): print("lr:", lr_scheduler.get_last_lr()[0]) # train for one epoch train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test_acc1 = {:3.1f}".format(acc1)) logger.close()
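# Sketch of the inverse-decay schedule used in main() above. The lambda returns
# args.lr * (1 + lr_gamma * x) ** (-lr_decay) directly, which only yields sensible
# learning rates if the parameter groups carry relative multipliers (around 1.0),
# which classifier.get_parameters() is assumed to do in the code above. The
# numbers below are illustrative.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

lr, lr_gamma, lr_decay = 0.01, 0.001, 0.75
param = torch.nn.Parameter(torch.zeros(3))

# base lr of 1.0 for the group; the LambdaLR factor supplies the actual value
optimizer = SGD([{"params": [param], "lr": 1.0}], lr=1.0, momentum=0.9)
lam = lambda x: lr * (1. + lr_gamma * float(x)) ** (-lr_decay)
scheduler = LambdaLR(optimizer, lam)

for step in (0, 1000, 10000):
    print(step, lam(step))  # effective lr at that step, since the group's base lr is 1.0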
def train(rank, args, tokenizer, train_dataset, test_dataset, model_s, model_t, params_to_tune, head_importance=None, loss_num=-1, tune_iter=0): """ Train the model """ global train_count train_count += 1 world_size = 1 if rank < 0 else torch.distributed.get_world_size() if rank in [-1, 0]: printlog("Train stage: ", train_count) printlog(model_s) if head_importance is not None: head_mask = torch.ones(*list(head_importance.shape)).to(args.device) head_mask.requires_grad_(requires_grad=True) else: head_mask = None num_train_epochs = args.num_train_epochs if loss_num > 0: num_train_epochs = 0.25 #short train for incremental loss per_gpu_train_batch_size = args.per_gpu_train_batch_size train_batch_size = per_gpu_train_batch_size * world_size #get total batch size and if tune_iter > 0 and args.total_train_batch_size_for_tune: total_train_batch_size = args.total_train_batch_size_for_tune else: total_train_batch_size = args.total_train_batch_size gradient_accumulation_steps = total_train_batch_size // train_batch_size if tune_iter > 0 and args.learning_rate_for_tune: learning_rate = args.learning_rate_for_tune else: learning_rate = args.learning_rate if check_model_type(model_s, BertModelEMB): #use 2 datasets for embedding question and context separatly if rank in [-1, 0]: printlog("dataset_q size", len(train_dataset.q_dataset)) printlog("dataset_c size", len(train_dataset.c_dataset)) datasets = [train_dataset.q_dataset, train_dataset.c_dataset] else: if rank in [-1, 0]: printlog("dataset size", len(train_dataset)) datasets = [train_dataset] if rank > -1: #for distributed train use sample that take only part of samples for each process train_dataloaders = [ DataLoader(dataset, sampler=torch.utils.data.distributed.DistributedSampler( dataset, rank=rank), batch_size=per_gpu_train_batch_size) for dataset in datasets ] else: train_dataloaders = [ DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=train_batch_size, num_workers=4) for dataset in datasets ] steps_per_epoch = sum(len(d) for d in train_dataloaders) steps_total = int(steps_per_epoch // gradient_accumulation_steps * num_train_epochs) # Prepare optimizer and scheduler name_set = set() for n, p in model_s.named_parameters(): if any(p is pp for pp in params_to_tune): name_set.add(n) named_params = [(n, p) for n, p in model_s.named_parameters() if n in name_set] if rank in [-1, 0]: for n, p in named_params: printlog('param for tune', n) def new_optimizer(): return AdamW([p for n, p in named_params], lr=learning_rate, eps=1e-08, weight_decay=0.0) optimizer = new_optimizer() def lr_lambda(current_step): p = float(current_step) / float(steps_total) warmup = 0.01 if p < warmup: return p / warmup p = (p - warmup) / (1 - warmup) return 1 if tune_iter == 0 else max(1 - p, 0) scheduler = LambdaLR(optimizer, lr_lambda) if rank in [-1, 0]: printlog("epoches", num_train_epochs) printlog("per_gpu_train_batch_size", per_gpu_train_batch_size) printlog("n_gpu", args.n_gpu) printlog("world_size", world_size) printlog("gradient_accumulation_steps", gradient_accumulation_steps) printlog("total train batch size", train_batch_size * gradient_accumulation_steps) printlog("steps_total", steps_total) restore_count = 0 if rank in [-1, 0]: if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) restore_file = os.path.join(args.output_dir, 'last_good_state.pth') restore_loss = None losses_list = [] global_step = 0 for epoch in range(math.ceil(num_train_epochs)): switch_to_train(rank, model_t) switch_to_train(rank, model_s) model_s.zero_grad() 
utils.sync_models(rank, model_s) time_last = time.time() for train_dataloader in train_dataloaders: printlog("rank", rank, "len(train_dataloader)", len(train_dataloader)) if rank > -1: train_dataloader.sampler.set_epoch(epoch) if len(train_dataloaders) > 1: # reset last loss to avoid restore due to dataset changing printlog("rank", rank, "reset restore_loss") restore_loss = None for step, batch in enumerate(train_dataloader): epoch_fp = epoch + step / len(train_dataloader) if epoch_fp > num_train_epochs: break inputs = { 'input_ids': batch[0].to(args.device), 'attention_mask': batch[1].to(args.device), 'token_type_ids': batch[2].to(args.device) } outputs_s = model_s(**inputs, head_mask=head_mask, output_hidden_states=True) losses = [] with torch.no_grad(): outputs_t = model_t(**inputs, output_hidden_states=True) out_s, out_t = outputs_s[-1], outputs_t[-1] assert len( out_s ) == model_s.config.num_hidden_layers + 1, "can not find hidden states in student model outputs" assert len( out_t ) == model_t.config.num_hidden_layers + 1, "can not find hidden states in teacher model outputs" if len(out_s) != len(out_t): #the student and teacher outputs are not aligned. try to find teacher output for each student output n_s, n_t = len(out_s), len(out_t) out_t = [ out_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s) ] assert len(out_s) == len( out_t ), "can not align number of outputs between student and teacher" assert all( s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape) ), "output shapes for student and teacher are not the same" out_pairs = list(zip(out_s, out_t)) if loss_num > 0: out_pairs = out_pairs[:loss_num] losses += [(s - t.detach()).pow(2).mean() for s, t in out_pairs] losses_list.append([l.item() for l in losses]) if tune_iter == 0: loss = sum(losses) / len(losses) else: weights = [ args.loss_weight_alpha**i for i in range(len(losses)) ] losses_w = [w * l for w, l in zip(weights, losses)] loss = sum(losses_w) / sum(weights) if gradient_accumulation_steps > 1: loss = loss / gradient_accumulation_steps loss.backward() del out_s del out_t del outputs_s del outputs_t if head_importance is not None: #collect gradient statistics to find most valuable heads head_mask.grad.detach_() head_importance += (head_mask.grad.abs().detach() - head_importance) * 0.001 head_mask.grad.zero_() if (step + 1) % gradient_accumulation_steps == 0: global_step += 1 #sync gradients before calc step utils.sync_grads(rank, named_params, global_step == 1) torch.nn.utils.clip_grad_norm_( [p for n, p in named_params], 1) optimizer.step() scheduler.step() model_s.zero_grad() if (step + 1) % 50 == 0: str_out = "{} ep {:.2f} lrp {:.2f} rc {:02}".format( train_count, epoch_fp, np.log10(scheduler.get_last_lr()[0]), restore_count) ll = np.array(losses_list).mean(0) if rank > -1: #sync indicators llt = torch.tensor(ll).to(args.device) torch.distributed.all_reduce( llt, op=torch.distributed.ReduceOp.SUM) ll = llt.cpu().numpy() / float(world_size) loss = ll.mean() str_out += " loss {:.4f}".format(loss) losses_txt = ["{:.3f}".format(l) for l in ll] if tune_iter > 0: losses_txt = [ "{:.2f}x".format(w) + lt for w, lt in zip(weights, losses_txt) ] str_out += " ll " + " ".join(losses_txt) if time_last: dt_iter = (time.time() - time_last) / len(losses_list) dt_ep = dt_iter * steps_per_epoch str_out += " it {:.1f}s".format(dt_iter) str_out += " ep {:.1f}m".format(dt_ep / (60)) str_out += " eta {:.1f}h".format( dt_ep * (num_train_epochs - epoch_fp) / (60 * 60)) losses_list = [] time_last = time.time() if rank in [-1, 0]: 
logger.info(str_out) if rank > -1: #sync losses loss_tensor = torch.tensor([loss], device=args.device) torch.distributed.all_reduce( loss_tensor, op=torch.distributed.ReduceOp.SUM) loss = loss_tensor.item() / world_size if restore_loss is None or loss < restore_loss * 1.5: #good result lets save it restore_loss = loss if rank in [-1, 0]: torch.save( { 'model_state_dict': model_s.state_dict(), 'optimizer_state_dict': optimizer.state_dict() }, restore_file) if rank > -1: torch.distributed.barrier() else: #bad result lets restore restore_count += 1 logger.info( "rank {} restore #{} from {} with {} loss". format(rank, restore_count, restore_file, restore_loss)) checkpoint = torch.load(restore_file) model_s.load_state_dict( checkpoint['model_state_dict']) #optimizer.load_state_dict(checkpoint['optimizer_state_dict']) optimizer = new_optimizer() switch_to_train(rank, model_s) if loss_num <= 0: if rank in [-1, 0]: check_point_name = 'checkpoint-{:02}'.format(train_count) save_model(args, model_s, tokenizer, check_point_name) check_point_name = check_point_name + '-{:02}'.format(epoch + 1) switch_to_eval(rank, model_s) result_s = evaluate(args, model_s, test_dataset) for k, v in result_s.items(): logger.info("{} {} {}".format(check_point_name, k, v)) if rank > -1: torch.distributed.barrier() if rank in [-1, 0]: if os.path.exists(restore_file): os.remove(restore_file)
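# A toy sketch of the layer-alignment trick inside the distillation loop above:
# when the student exposes fewer hidden states than the teacher, student layer i
# is matched to teacher layer (i * (n_t - 1)) // (n_s - 1) and the loss is the
# mean MSE over the matched pairs. The tensors here are random stand-ins for the
# real hidden states.
import torch

n_s, n_t, shape = 5, 13, (2, 8, 16)     # layer counts incl. embeddings; (B, T, H)
out_s = [torch.randn(shape, requires_grad=True) for _ in range(n_s)]
out_t = [torch.randn(shape) for _ in range(n_t)]

if len(out_s) != len(out_t):
    # subsample teacher layers so every student layer has a partner
    out_t = [out_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s)]

losses = [(s - t.detach()).pow(2).mean() for s, t in zip(out_s, out_t)]
loss = sum(losses) / len(losses)
loss.backward()
print(len(out_t), loss.item())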
def train(model, tokenizer, train_data, valid_data, args, eos=False): print('eos:', eos) model.train() train_dataset = TextDataset(train_data) train_dataloader = DataLoader( train_dataset, sampler=RandomSampler(train_dataset), batch_size=args.train_batch_size, num_workers=args.num_workers, collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length, eos=eos, add_noise=args.add_noise, tokenizer_type=args.tokenizer)) valid_dataset = TextDataset(valid_data) valid_dataloader = DataLoader( valid_dataset, sampler=SequentialSampler(valid_dataset), batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length, eos=eos, tokenizer_type=args.tokenizer)) valid_noisy = [x['noisy'] for x in valid_data] valid_clean = [x['clean'] for x in valid_data] epochs = (args.max_steps - 1) // len(train_dataloader) + 1 optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=eval(args.adam_betas), eps=args.eps, weight_decay=args.weight_decay) lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else ( x / args.num_warmup_steps)**-0.5 scheduler = LambdaLR(optimizer, lr_lambda) if not args.step_lr else StepLR( optimizer, args.eval_interval, args.step_gamma) step = 0 best_val_gleu = -float("inf") meter = Meter() for epoch in range(1, epochs + 1): print("===EPOCH: ", epoch) for batch in train_dataloader: if step % args.eval_interval == 0: start_eval = time.time() (val_loss, val_loss_token), valid_str = evaluate(model, valid_dataloader, args) if args.beamsearch: prediction = correct_beam(model, tokenizer, valid_noisy, args, eos=eos, length_limit=0.15) else: prediction = correct(model, tokenizer, valid_noisy, args, eos=eos, length_limit=0.15) val_em = em(prediction, valid_clean) cnt = 0 for noisy, pred, clean in zip(valid_noisy, prediction, valid_clean): print(f'[{noisy}], [{pred}], [{clean}]') # 30개 출력하기 cnt += 1 if cnt == 30: break val_gleu = gleu(prediction, valid_clean) logger.info('-' * 89) logger.info( f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}' ) logger.info('-' * 89) nsml.report(step=step, scope=locals(), summary=True, valid__loss_sent=val_loss, valid__token_ppl=math.exp(val_loss_token), valid__em=val_em, valid__gleu=val_gleu) # if step % (args.eval_interval * 5) == 0: # by 5000 steps # nsml.save(step) if val_gleu > best_val_gleu: best_val_gleu = val_gleu nsml.save('best') if val_gleu >= 86.0: nsml.save(step) if args.mode == 'pretrain': torch.save(model.state_dict(), args.model_name) meter.start += time.time() - start_eval model.train() step += 1 batch = tuple(t.to(args.device) for t in batch) loss, items = calc_loss(model, batch) meter.add(*items) loss.backward() if args.max_grad_norm > 0: nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() model.zero_grad() scheduler.step() if step % args.log_interval == 0: lr = scheduler.get_last_lr()[0] loss_sent, loss_token = meter.average() logger.info( f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}') nsml.report(step=step, scope=locals(), summary=True, train__lr=lr, train__loss_sent=loss_sent, train__token_ppl=math.exp(loss_token)) meter.init() if step >= args.max_steps: break if step >= args.max_steps: break
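# Sketch of the warmup-then-inverse-square-root factor used by the default
# scheduler in train() above: the multiplier ramps linearly to 1.0 over
# num_warmup_steps and then decays as (step / num_warmup_steps) ** -0.5.
# num_warmup_steps and the base lr below are made-up values for illustration.
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

num_warmup_steps, base_lr = 4000, 3e-4
lr_lambda = lambda x: x / num_warmup_steps if x <= num_warmup_steps else (x / num_warmup_steps) ** -0.5

optimizer = Adam([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
scheduler = LambdaLR(optimizer, lr_lambda)

for step in (0, 2000, 4000, 16000):
    print(step, lr_lambda(step) * base_lr)  # effective lr the scheduler would set at that step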
def main(config, progress): # save config with open("./log/configs.json", "a") as f: json.dump(config, f) f.write("\n") cprint("*"*80) cprint("Experiment progress: {0:.2f}%".format(progress*100)) cprint("*"*80) metrics = {} # data hyper-params data_path = config["data_path"] keyword_path = config["keyword_path"] pretrained_wordvec_path = config["pretrained_wordvec_path"] data_dir = "/".join(data_path.split("/")[:-1]) dataset = data_path.split("/")[-2] # convai2 or casual test_mode = bool(config["test_mode"]) save_model_path = config["save_model_path"] min_context_len = config["min_context_len"] max_context_len = config["max_context_len"] max_sent_len = config["max_sent_len"] max_keyword_len = config["max_keyword_len"] max_vocab_size = config["max_vocab_size"] max_keyword_vocab_size = config["max_keyword_vocab_size"] remove_self_loop = bool(config["remove_self_loop"]) # model hyper-params config_id = config["config_id"] model = config["model"] gnn = config["gnn"] aggregation = config["aggregation"] utterance_encoder = config["utterance_encoder"] use_last_k_utterances = config["use_last_k_utterances"] use_CN_hopk_graph = config["use_CN_hopk_graph"] use_utterance_concepts = bool(config["use_utterance_concepts"]) combine_node_emb = config["combine_node_emb"] # replace, mean, max, concat, concept_encoder = config["concept_encoder"] embed_size = config["embed_size"] use_pretrained_word_embedding = bool(config["use_pretrained_word_embedding"]) fix_word_embedding = bool(config["fix_word_embedding"]) hidden_size = config["hidden_size"] n_layers = config["n_layers"] bidirectional = bool(config["bidirectional"]) n_heads = config["n_heads"] dropout = config["dropout"] # training hyper-params batch_size = config["batch_size"] epochs = config["epochs"] lr = config["lr"] lr_decay = config["lr_decay"] seed = config["seed"] device = torch.device(config["device"]) fp16 = bool(config["fp16"]) fp16_opt_level = config["fp16_opt_level"] # set seed random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) if "convai2" in data_dir and min_context_len != 2: raise ValueError("convai2 dataset has min context len of 2") if use_pretrained_word_embedding and str(embed_size) not in pretrained_wordvec_path: raise ValueError("embedding size and pretrained_wordvec_path not match") # load data cprint("Loading conversation data...") train, valid, test = load_pickle(data_path) train_keyword, valid_keyword, test_keyword = load_pickle(keyword_path) if test_mode: cprint("Testing model...") train = train + valid train_keyword = train_keyword + valid_keyword valid = test valid_keyword = test_keyword cprint(len(train), len(train_keyword), len(valid), len(valid_keyword)) cprint("sample train: ", train[0]) cprint("sample train keyword: ", train_keyword[0]) cprint("sample valid: ", valid[0]) cprint("sample valid keyword: ", valid_keyword[0]) # clip and pad data train_padded_convs, train_padded_keywords = pad_and_clip_data(train, train_keyword, min_context_len, max_context_len+1, max_sent_len, max_keyword_len) valid_padded_convs, valid_padded_keywords = pad_and_clip_data(valid, valid_keyword, min_context_len, max_context_len+1, max_sent_len, max_keyword_len) cprint(len(train_padded_convs), len(train_padded_keywords), len(valid_padded_convs), len(valid_padded_keywords)) cprint("sample padded train: ", train_padded_convs[0]) cprint("sample padded train keyword: ", train_padded_keywords[0]) cprint("sample padded valid: ", valid_padded_convs[0]) cprint("sample padded valid keyword: ", valid_padded_keywords[0]) # build vocab if 
"convai2" in data_dir: test_padded_convs, _ = pad_and_clip_data(test, test_keyword, min_context_len, max_context_len+1, max_sent_len, max_keyword_len) word2id = build_vocab(train_padded_convs + valid_padded_convs + test_padded_convs, max_vocab_size) # use entire dataset for vocab as done in (tang 2019) else: word2id = build_vocab(train_padded_convs, max_vocab_size) keyword2id = build_vocab(train_padded_keywords, max_keyword_vocab_size) id2keyword = {idx:w for w, idx in keyword2id.items()} for w in keyword2id: if w not in word2id: word2id[w] = len(word2id) # add OOV keywords to word2id id2word = {idx:w for w, idx in word2id.items()} keywordid2wordid = [word2id[id2keyword[i]] if id2keyword[i] in word2id else word2id["<unk>"] for i in range(len(keyword2id))] vocab_size = len(word2id) keyword_vocab_size = len(keyword2id) cprint("vocab size: ", vocab_size) cprint("keyword vocab size: ", keyword_vocab_size) CN_hopk_edge_index, CN_hopk_nodeid2wordid, keywordid2nodeid, node2id = None, None, None, None keyword_mask_matrix = None if use_CN_hopk_graph > 0: cprint("Loading CN_hopk edge index...") """ CN_graph_dict: { edge_index: 2D list (num_edges, 2), edge_type: list (num_edges, ), edge_weight: list (num_edges, ), relation2id: {}, nodeid2wordid: 2D list (num_nodes, 10) } """ CN_hopk_graph_path = "./data/{0}/CN_graph_{1}hop_ge1.pkl".format(dataset, use_CN_hopk_graph) cprint("Loading graph from ", CN_hopk_graph_path) CN_hopk_graph_dict = load_nx_graph_hopk(CN_hopk_graph_path, word2id, keyword2id) CN_hopk_edge_index = torch.LongTensor(CN_hopk_graph_dict["edge_index"]).transpose(0,1).to(device) # (2, num_edges) CN_hopk_nodeid2wordid = torch.LongTensor(CN_hopk_graph_dict["nodeid2wordid"]).to(device) # (num_nodes, 10) node2id = CN_hopk_graph_dict["node2id"] id2node = {idx:w for w,idx in node2id.items()} keywordid2nodeid = [node2id[id2keyword[i]] if id2keyword[i] in node2id else node2id["<unk>"] for i in range(len(keyword2id))] keywordid2nodeid = torch.LongTensor(keywordid2nodeid).to(device) keyword_mask_matrix = torch.from_numpy(CN_hopk_graph_dict["edge_mask"]).float() # numpy array of (keyword_vocab_size, keyword_vocab_size) cprint("building keyword mask matrix...") if remove_self_loop: keyword_mask_matrix[torch.arange(keyword_vocab_size), torch.arange(keyword_vocab_size)] = 0 cprint("keyword mask matrix non-zeros ratio: ", keyword_mask_matrix.mean()) cprint("average number of neighbors: ", keyword_mask_matrix.sum(dim=1).mean()) cprint("sample keyword mask matrix: ", keyword_mask_matrix[:8,:8]) keyword_mask_matrix = keyword_mask_matrix.to(device) cprint("edge index shape: ", CN_hopk_edge_index.shape) cprint("edge index[:,:8]", CN_hopk_edge_index[:,:8]) cprint("nodeid2wordid shape: ", CN_hopk_nodeid2wordid.shape) cprint("nodeid2wordid[:5,:8]", CN_hopk_nodeid2wordid[:5,:8]) cprint("keywordid2nodeid shape: ", keywordid2nodeid.shape) cprint("keywordid2nodeid[:8]", keywordid2nodeid[:8]) # convert edge index if utterance_encoder != "": keywordid2wordid = torch.LongTensor(keywordid2wordid).to(device) cprint("keywordid2wordid shape: ", keywordid2wordid.shape) cprint("keywordid2wordid", keywordid2wordid[:8]) # convert tokens to ids train_conv_ids = convert_convs_to_ids(train_padded_convs, word2id) valid_conv_ids = convert_convs_to_ids(valid_padded_convs, word2id) train_keyword_ids = convert_convs_to_ids(train_padded_keywords, keyword2id) valid_keyword_ids = convert_convs_to_ids(valid_padded_keywords, keyword2id) cprint(len(train_conv_ids), len(train_keyword_ids), len(valid_conv_ids), len(valid_keyword_ids)) 
cprint("sample train token ids: ", train_conv_ids[0]) cprint("sample train keyword ids: ", train_keyword_ids[0]) cprint("sample valid token ids: ", valid_conv_ids[0]) cprint("sample valid keyword ids: ", valid_keyword_ids[0]) num_examples = len(train_keyword_ids) # create model if model in ["KW_GNN"]: model_kwargs = { "embed_size": embed_size, "vocab_size": vocab_size, "keyword_vocab_size": keyword_vocab_size, "hidden_size": hidden_size, "output_size": hidden_size, "n_layers": n_layers, "gnn": gnn, "aggregation": aggregation, "n_heads": n_heads, "dropout": dropout, "bidirectional": bidirectional, "utterance_encoder": utterance_encoder, "keywordid2wordid": keywordid2wordid, "keyword_mask_matrix": keyword_mask_matrix, "nodeid2wordid": CN_hopk_nodeid2wordid, "keywordid2nodeid": keywordid2nodeid, "concept_encoder": concept_encoder, "combine_node_emb": combine_node_emb } cprint("Building model...") model = globals()[config["model"]](**model_kwargs) # cprint(model.edge_weight.shape, model.edge_weight.requires_grad) pretrained_word_embedding = None if use_pretrained_word_embedding: # load pretrained word embedding cprint("Loading pretrained word embeddings...") pretrained_wordvec_name = pretrained_wordvec_path.split("/")[-1][:-4] word_vectors_path = os.path.join(data_dir, "word_vectors_{0}.pkl".format(pretrained_wordvec_name)) keyword2id = word2id if os.path.exists(word_vectors_path): cprint("Loading pretrained word embeddings from ", word_vectors_path) with open(word_vectors_path, "rb") as f: word_vectors = pickle.load(f) else: cprint("Loading pretrained word embeddings from scratch...") word_vectors = load_vectors(pretrained_wordvec_path, keyword2id) cprint("Saving pretrained word embeddings to ", word_vectors_path) with open(word_vectors_path, "wb") as f: pickle.dump(word_vectors, f) print("loaded word vector size: ", len(word_vectors)) pretrained_word_embedding = np.zeros((len(keyword2id), embed_size)) for w, i in keyword2id.items(): if w in word_vectors: pretrained_word_embedding[i] = np.array(word_vectors[w]) else: pretrained_word_embedding[i] = np.random.randn(embed_size)/9 pretrained_word_embedding[0] = 0 # 0 for PAD embedding pretrained_word_embedding = torch.from_numpy(pretrained_word_embedding).float() cprint("word embedding size: ", pretrained_word_embedding.shape) model.init_embedding(pretrained_word_embedding, fix_word_embedding) cprint(model) cprint("number of parameters: ", count_parameters(model)) model.to(device) # optimization amp = None if fp16: from apex import amp optimizer = torch.optim.Adam(model.parameters(), lr=lr) # scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: lr_decay ** epoch) scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1/(1+lr_decay*step/(num_examples/batch_size))) if fp16: model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) # training epoch_train_losses = [] epoch_valid_losses = [] epoch_valid_precisions = [] epoch_valid_recalls = [] best_model_statedict = {} cprint("Start training...") for epoch in range(epochs): cprint("-"*80) cprint("Epoch", epoch+1) train_batches = create_batches_keyword_prediction(train_conv_ids, train_keyword_ids, 2*max_keyword_len, batch_size, \ shuffle=True, remove_self_loop=remove_self_loop, keywordid2wordid=keywordid2wordid, \ keyword_mask_matrix=keyword_mask_matrix.cpu().numpy(), use_last_k_utterances=use_last_k_utterances, use_utterance_concepts=use_utterance_concepts, \ keyword2id=keyword2id, node2id=node2id, id2word=id2word) valid_batches = 
create_batches_keyword_prediction(valid_conv_ids, valid_keyword_ids, 2*max_keyword_len, batch_size, \ shuffle=False, remove_self_loop=remove_self_loop, keywordid2wordid=keywordid2wordid, \ keyword_mask_matrix=keyword_mask_matrix.cpu().numpy(), use_last_k_utterances=use_last_k_utterances, use_utterance_concepts=use_utterance_concepts, \ keyword2id=keyword2id, node2id=node2id, id2word=id2word) cprint("train batches 1st example: ") for k, v in train_batches[0].items(): if k == "batch_X_keywords": cprint(k, v[0], [id2keyword[w] for w in v[0]]) if k == "batch_X_utterances": utters = [] for utter in v[0]: utters.append([id2word[w] for w in utter]) cprint(k, v[0], utters) if k == "batch_X_concepts" and len(v) > 0: cprint(k, v[0], [id2node[w] for w in v[0]]) if k == "batch_y": cprint(k, v[0], [id2keyword[w] for w in v[0]]) model.train() train_loss, (train_precision, train_recall) = run_epoch(train_batches, model, optimizer, epoch=epoch, training=True, device=device, \ fp16=fp16, amp=amp, step_scheduler=scheduler, keyword_mask_matrix=keyword_mask_matrix, keywordid2wordid=keywordid2wordid, \ CN_hopk_edge_index=CN_hopk_edge_index, use_utterance_concepts=use_utterance_concepts) cprint("Config id: {}, Epoch {}: train precision: {}, train recall: {}" .format(config_id, epoch+1, train_precision, train_recall)) model.eval() valid_loss, (valid_precision, valid_recall) = run_epoch(valid_batches, model, optimizer, epoch=epoch, training=False, device=device, \ keyword_mask_matrix=keyword_mask_matrix, keywordid2wordid=keywordid2wordid, \ CN_hopk_edge_index=CN_hopk_edge_index, use_utterance_concepts=use_utterance_concepts) # scheduler.step() cprint("Config id: {}, Epoch {}: train loss: {}, valid loss: {}, valid precision: {}, valid recall: {}" .format(config_id, epoch+1, train_loss, valid_loss, valid_precision, valid_recall)) if scheduler is not None: cprint("Current learning rate: ", scheduler.get_last_lr()) epoch_train_losses.append(train_loss) epoch_valid_losses.append(valid_loss) epoch_valid_precisions.append(valid_precision) epoch_valid_recalls.append(valid_recall) if save_model_path != "": if epoch == 0: for k, v in model.state_dict().items(): best_model_statedict[k] = v.cpu() else: if epoch_valid_recalls[-1][0] == max([recall1 for recall1, _, _ in epoch_valid_recalls]): for k, v in model.state_dict().items(): best_model_statedict[k] = v.cpu() # early stopping if len(epoch_valid_recalls) >= 3 and epoch_valid_recalls[-1][0] < epoch_valid_recalls[-2][0] and epoch_valid_recalls[-2][0] < epoch_valid_recalls[-3][0]: break config.pop("seed") config.pop("config_id") metrics["config"] = config metrics["score"] = max([recall[0] for recall in epoch_valid_recalls]) metrics["epoch"] = np.argmax([recall[0] for recall in epoch_valid_recalls]).item() metrics["recall"] = epoch_valid_recalls[metrics["epoch"]] metrics["precision"] = epoch_valid_precisions[metrics["epoch"]] if save_model_path: cprint("Saving model to ", save_model_path) best_model_statedict["word2id"] = keyword2id best_model_statedict["model_kwargs"] = model_kwargs torch.save(best_model_statedict, save_model_path) return metrics
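# Sketch of the per-step inverse-time decay used by the optimizer setup above:
# the multiplier is 1 / (1 + lr_decay * step / steps_per_epoch), so the lr halves
# after 1/lr_decay epochs of optimization steps. num_examples, batch_size and
# lr_decay below are illustrative, not the configured values.
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

num_examples, batch_size, lr_decay, base_lr = 100_000, 64, 0.5, 1e-3
steps_per_epoch = num_examples / batch_size

optimizer = Adam([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1 / (1 + lr_decay * step / steps_per_epoch))

for step in (0, int(steps_per_epoch), int(4 * steps_per_epoch)):
    print(step, base_lr / (1 + lr_decay * step / steps_per_epoch))  # effective lr at that step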
def main(config, progress): # save config with open("./log/configs.json", "a") as f: json.dump(config, f) f.write("\n") cprint("*" * 80) cprint("Experiment progress: {0:.2f}%".format(progress * 100)) cprint("*" * 80) metrics = {} # data hyper-params data_path = config["data_path"] keyword_path = config["keyword_path"] pretrained_wordvec_path = config["pretrained_wordvec_path"] data_dir = "/".join(data_path.split("/")[:-1]) dataset = data_path.split("/")[-2] # convai2 or casual test_mode = bool(config["test_mode"]) save_model_path = config["save_model_path"] load_kw_prediction_path = config["load_kw_prediction_path"] min_context_len = config["min_context_len"] max_context_len = config["max_context_len"] max_sent_len = config["max_sent_len"] max_keyword_len = config["max_keyword_len"] max_vocab_size = config["max_vocab_size"] max_keyword_vocab_size = config["max_keyword_vocab_size"] flatten_context = config["flatten_context"] # model hyper-params config_id = config["config_id"] model = config["model"] use_CN_hopk_graph = config["use_CN_hopk_graph"] use_utterance_concepts = use_CN_hopk_graph > 0 concept_encoder = config["concept_encoder"] combine_word_concepts = config["combine_word_concepts"] gnn = config["gnn"] encoder = config["encoder"] aggregation = config["aggregation"] use_keywords = bool(config["use_keywords"]) keyword_score_weight = config["keyword_score_weight"] keyword_encoder = config["keyword_encoder"] # mean, max, GRU, any_max embed_size = config["embed_size"] use_pretrained_word_embedding = bool( config["use_pretrained_word_embedding"]) fix_word_embedding = bool(config["fix_word_embedding"]) gnn_hidden_size = config["gnn_hidden_size"] gnn_layers = config["gnn_layers"] encoder_hidden_size = config["encoder_hidden_size"] encoder_layers = config["encoder_layers"] n_heads = config["n_heads"] dropout = config["dropout"] # training hyper-params batch_size = config["batch_size"] epochs = config["epochs"] lr = config["lr"] lr_decay = config["lr_decay"] seed = config["seed"] device = torch.device(config["device"]) fp16 = bool(config["fp16"]) fp16_opt_level = config["fp16_opt_level"] # set seed random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) if "convai2" in data_dir and min_context_len != 2: raise ValueError("convai2 dataset has min context len of 2") if use_pretrained_word_embedding and str( embed_size) not in pretrained_wordvec_path: raise ValueError( "embedding size and pretrained_wordvec_path not match") if use_keywords and load_kw_prediction_path == "": raise ValueError( "kw model path needs to be provided when use_keywords is True") # load data cprint("Loading conversation data...") train, valid, test = load_pickle(data_path) train_keyword, valid_keyword, test_keyword = load_pickle(keyword_path) train_candidate, valid_candidate = None, None # load 20 candidates train_candidate, valid_candidate, test_candidate = load_pickle( os.path.join(data_dir, "candidate.pkl")) if test_mode: cprint("Testing model...") train = train + valid train_keyword = train_keyword + valid_keyword valid = test valid_keyword = test_keyword train_candidate = train_candidate + valid_candidate valid_candidate = test_candidate cprint("sample train: ", train[0]) cprint("sample train keyword: ", train_keyword[0]) cprint("sample valid: ", valid[0]) cprint("sample valid keyword: ", valid_keyword[0]) # clip and pad data train_padded_convs, train_padded_keywords = pad_and_clip_data( train, train_keyword, min_context_len, max_context_len + 1, max_sent_len, max_keyword_len) valid_padded_convs, 
valid_padded_keywords = pad_and_clip_data( valid, valid_keyword, min_context_len, max_context_len + 1, max_sent_len, max_keyword_len) train_padded_candidates = pad_and_clip_candidate(train_candidate, max_sent_len) valid_padded_candidates = pad_and_clip_candidate(valid_candidate, max_sent_len) # build vocab if "convai2" in data_dir: test_padded_convs, _ = pad_and_clip_data(test, test_keyword, min_context_len, max_context_len + 1, max_sent_len, max_keyword_len) word2id = build_vocab(train_padded_convs + valid_padded_convs + test_padded_convs, max_vocab_size) # use entire dataset for vocab else: word2id = build_vocab(train_padded_convs, max_vocab_size) keyword2id = build_vocab(train_padded_keywords, max_keyword_vocab_size) id2keyword = {idx: w for w, idx in keyword2id.items()} for w in keyword2id: if w not in word2id: word2id[w] = len(word2id) # add OOV keywords to word2id id2word = {idx: w for w, idx in word2id.items()} cprint("keywords that are not in word2id: ", set(keyword2id.keys()) - set(word2id.keys())) vocab_size = len(word2id) keyword_vocab_size = len(keyword2id) cprint("vocab size: ", vocab_size) cprint("keyword vocab size: ", keyword_vocab_size) # create a mapping from keyword id to word id keywordid2wordid = None train_candidate_keyword_ids, valid_candidate_keyword_ids = None, None if use_keywords: keywordid2wordid = [ word2id[id2keyword[i]] if id2keyword[i] in word2id else word2id["<unk>"] for i in range(len(keyword2id)) ] keywordid2wordid = torch.LongTensor(keywordid2wordid).to(device) # load candidate keywords candidate_keyword_path = os.path.join(data_dir, "candidate_keyword.pkl") if os.path.exists(candidate_keyword_path): cprint("Loading candidate keywords from ", candidate_keyword_path) train_candidate_keywords, valid_candidate_keywords, test_candidate_keywords = load_pickle( candidate_keyword_path) else: cprint("Creating candidate keywords...") train_candidate_keywords = extract_keywords_from_candidates( train_candidate, keyword2id) valid_candidate_keywords = extract_keywords_from_candidates( valid_candidate, keyword2id) test_candidate_keywords = extract_keywords_from_candidates( test_candidate, keyword2id) save_pickle((train_candidate_keywords, valid_candidate_keywords, test_candidate_keywords), candidate_keyword_path) if test_mode: train_candidate_keywords = train_candidate_keywords + valid_candidate_keywords valid_candidate_keywords = test_candidate_keywords # pad cprint("Padding candidate keywords...") train_padded_candidate_keywords = pad_and_clip_candidate( train_candidate_keywords, max_keyword_len) valid_padded_candidate_keywords = pad_and_clip_candidate( valid_candidate_keywords, max_keyword_len) # convert candidates to ids cprint("Converting candidate keywords to ids...") train_candidate_keyword_ids = convert_candidates_to_ids( train_padded_candidate_keywords, keyword2id) valid_candidate_keyword_ids = convert_candidates_to_ids( valid_padded_candidate_keywords, keyword2id) # load CN graph CN_hopk_edge_index, CN_hopk_nodeid2wordid, keywordid2nodeid, node2id, CN_hopk_edge_matrix_mask = None, None, None, None, None if use_CN_hopk_graph > 0: cprint("Loading CN_hopk edge index...") """ CN_graph_dict: { edge_index: 2D list (num_edges, 2), edge_weight: list (num_edges, ), nodeid2wordid: 2D list (num_nodes, 10), edge_mask: numpy array of (keyword_vocab_size, keyword_vocab_size) } """ CN_hopk_graph_path = "./data/{0}/CN_graph_{1}hop_ge1.pkl".format( dataset, use_CN_hopk_graph) cprint("Loading graph from ", CN_hopk_graph_path) CN_hopk_graph_dict = 
load_nx_graph_hopk(CN_hopk_graph_path, word2id, keyword2id) CN_hopk_edge_index = torch.LongTensor( CN_hopk_graph_dict["edge_index"]).transpose(0, 1).to( device) # (2, num_edges) CN_hopk_nodeid2wordid = torch.LongTensor( CN_hopk_graph_dict["nodeid2wordid"]).to(device) # (num_nodes, 10) node2id = CN_hopk_graph_dict["node2id"] id2node = {idx: w for w, idx in node2id.items()} keywordid2nodeid = [ node2id[id2keyword[i]] if id2keyword[i] in node2id else node2id["<unk>"] for i in range(len(keyword2id)) ] keywordid2nodeid = torch.LongTensor(keywordid2nodeid).to(device) cprint("edge index shape: ", CN_hopk_edge_index.shape) cprint("edge index[:,:8]", CN_hopk_edge_index[:, :8]) cprint("nodeid2wordid shape: ", CN_hopk_nodeid2wordid.shape) cprint("nodeid2wordid[:5,:8]", CN_hopk_nodeid2wordid[:5, :8]) cprint("keywordid2nodeid shape: ", keywordid2nodeid.shape) cprint("keywordid2nodeid[:8]", keywordid2nodeid[:8]) # convert tokens to ids train_conv_ids = convert_convs_to_ids(train_padded_convs, word2id) valid_conv_ids = convert_convs_to_ids(valid_padded_convs, word2id) train_keyword_ids = convert_convs_to_ids(train_padded_keywords, keyword2id) valid_keyword_ids = convert_convs_to_ids(valid_padded_keywords, keyword2id) train_candidate_ids, valid_candidate_ids = None, None train_candidate_ids = convert_candidates_to_ids(train_padded_candidates, word2id) valid_candidate_ids = convert_candidates_to_ids(valid_padded_candidates, word2id) keyword_mask_matrix = None if use_CN_hopk_graph > 0: keyword_mask_matrix = torch.from_numpy( CN_hopk_graph_dict["edge_mask"]).float( ) # numpy array of (keyword_vocab_size, keyword_vocab_size) cprint("building keyword mask matrix...") keyword_mask_matrix[ torch.arange(keyword_vocab_size), torch.arange(keyword_vocab_size)] = 0 # remove self loop cprint("keyword mask matrix non-zeros ratio: ", keyword_mask_matrix.mean()) cprint("average number of neighbors: ", keyword_mask_matrix.sum(dim=1).mean()) cprint("sample keyword mask matrix: ", keyword_mask_matrix[:8, :8]) keyword_mask_matrix = keyword_mask_matrix.to(device) num_examples = len(train_conv_ids) cprint("sample train token ids: ", train_conv_ids[0]) cprint("sample train keyword ids: ", train_keyword_ids[0]) cprint("sample valid token ids: ", valid_conv_ids[0]) cprint("sample valid keyword ids: ", valid_keyword_ids[0]) cprint("sample train candidate ids: ", train_candidate_ids[0]) cprint("sample valid candidate ids: ", valid_candidate_ids[0]) if use_keywords: cprint("sample train candidate keyword ids: ", train_candidate_keyword_ids[0]) cprint("sample valid candidate keyword ids: ", valid_candidate_keyword_ids[0]) # create model if model in ["CoGraphMatcher"]: model_kwargs = { "embed_size": embed_size, "vocab_size": vocab_size, "gnn_hidden_size": gnn_hidden_size, "gnn_layers": gnn_layers, "encoder_hidden_size": encoder_hidden_size, "encoder_layers": encoder_layers, "n_heads": n_heads, "CN_hopk_edge_matrix_mask": CN_hopk_edge_matrix_mask, "nodeid2wordid": CN_hopk_nodeid2wordid, "keywordid2wordid": keywordid2wordid, "keywordid2nodeid": keywordid2nodeid, "concept_encoder": concept_encoder, "gnn": gnn, "encoder": encoder, "aggregation": aggregation, "use_keywords": use_keywords, "keyword_score_weight": keyword_score_weight, "keyword_encoder": keyword_encoder, "dropout": dropout, "combine_word_concepts": combine_word_concepts } # create keyword model kw_model = "" use_last_k_utterances = -1 if use_keywords: kw_model = load_kw_prediction_path.split( "/")[-1][:-3] # keyword prediction model name if "GNN" in kw_model: kw_model = 
"KW_GNN" use_last_k_utterances = 2 # load pretrained model cprint("Loading weights from ", load_kw_prediction_path) kw_model_checkpoint = torch.load(load_kw_prediction_path, map_location=device) if "word2id" in kw_model_checkpoint: keyword2id = kw_model_checkpoint.pop("word2id") if "model_kwargs" in kw_model_checkpoint: kw_model_kwargs = kw_model_checkpoint.pop("model_kwargs") kw_model = globals()[kw_model](**kw_model_kwargs) kw_model.load_state_dict(kw_model_checkpoint) kw_model.to(device) kw_model.eval() # set to evaluation mode, no training required cprint("Building model...") model = globals()[config["model"]](**model_kwargs) cprint("Initializing pretrained word embeddings...") pretrained_word_embedding = None if use_pretrained_word_embedding: # load pretrained word embedding cprint("Loading pretrained word embeddings...") pretrained_wordvec_name = pretrained_wordvec_path.split("/")[-1][:-4] word_vectors_path = os.path.join( data_dir, "word_vectors_{0}.pkl".format(pretrained_wordvec_name)) if os.path.exists(word_vectors_path): cprint("Loading pretrained word embeddings from ", word_vectors_path) with open(word_vectors_path, "rb") as f: word_vectors = pickle.load(f) else: cprint("Loading pretrained word embeddings from scratch...") word_vectors = load_vectors(pretrained_wordvec_path, word2id) cprint("Saving pretrained word embeddings to ", word_vectors_path) with open(word_vectors_path, "wb") as f: pickle.dump(word_vectors, f) cprint("pretrained word embedding size: ", len(word_vectors)) pretrained_word_embedding = np.zeros((len(word2id), embed_size)) for w, i in word2id.items(): if w in word_vectors: pretrained_word_embedding[i] = np.array(word_vectors[w]) else: pretrained_word_embedding[i] = np.random.randn(embed_size) / 9 pretrained_word_embedding[0] = 0 # 0 for PAD embedding pretrained_word_embedding = torch.from_numpy( pretrained_word_embedding).float() cprint("word embedding size: ", pretrained_word_embedding.shape) model.init_embedding(pretrained_word_embedding, fix_word_embedding) cprint(model) cprint("number of parameters: ", count_parameters(model)) model.to(device) # optimization amp = None if fp16: from apex import amp optimizer = torch.optim.Adam(model.parameters(), lr=lr) scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1 / (1 + lr_decay * step / (num_examples / batch_size))) if fp16: model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) # training epoch_train_losses = [] epoch_valid_losses = [] epoch_valid_precisions = [] epoch_valid_recalls = [] epoch_valid_MRRs = [] best_model_statedict = {} cprint("Start training...") for epoch in range(epochs): cprint("-" * 80) cprint("Epoch", epoch + 1) train_batches = create_batches_retrieval(train_conv_ids, train_keyword_ids, train_candidate_ids, train_candidate_keyword_ids, \ 2*max_keyword_len, batch_size, shuffle=True, use_keywords=use_keywords, use_candidate_keywords=use_keywords, use_utterance_concepts=use_utterance_concepts, \ node2id=node2id, id2word=id2word, flatten_context=flatten_context, use_last_k_utterances=use_last_k_utterances) valid_batches = create_batches_retrieval(valid_conv_ids, valid_keyword_ids, valid_candidate_ids, valid_candidate_keyword_ids, \ 2*max_keyword_len, batch_size, shuffle=False, use_keywords=use_keywords, use_candidate_keywords=use_keywords, use_utterance_concepts=use_utterance_concepts, \ node2id=node2id, id2word=id2word, flatten_context=flatten_context, use_last_k_utterances=use_last_k_utterances) if epoch == 0: cprint("number of optimization steps per epoch: ", 
len(train_batches)) # 3361 cprint("train batches 1st example: ") for k, v in train_batches[0].items(): if k == "batch_context": utters = [] for utter in v[0]: utters.append([id2word[w] for w in utter]) cprint("\n", k, v[0], utters) if k == "batch_candidates": utters = [] for utter in v[0]: utters.append([id2word[w] for w in utter]) cprint("\n", k, v[0], utters) if k == "batch_context_kw": cprint("\n", k, v[0], [id2keyword[w] for w in v[0]]) if k == "batch_candidates_kw": utters = [] for utter in v[0]: utters.append([id2keyword[w] for w in utter]) cprint("\n", k, v[0], utters) if k == "batch_context_concepts": if len(v[0][0]) > 0: utters = [] for utter in v[0]: utters.append([id2node[w] for w in utter]) cprint("\n", k, v[0], utters) if k == "batch_candidates_concepts": utters = [] for utter in v[0]: utters.append([id2node[w] for w in utter]) cprint("\n", k, v[0], utters) if k == "batch_context_for_keyword_prediction": utters = [] for utter in v[0]: utters.append([id2word[w] for w in utter]) cprint("\n", k, v[0], utters) if k == "batch_context_concepts_for_keyword_prediction": cprint("\n", k, v[0], [id2node[w] for w in v[0]]) model.train() train_loss, (_, _, _) = run_epoch(train_batches, model, optimizer, training=True, device=device, fp16=fp16, amp=amp, \ kw_model=kw_model, keyword_mask_matrix=keyword_mask_matrix, step_scheduler=scheduler, keywordid2wordid=keywordid2wordid, \ CN_hopk_edge_index=CN_hopk_edge_index) model.eval() valid_loss, (valid_precision, valid_recall, valid_MRR) = run_epoch(valid_batches, model, optimizer, training=False, device=device, \ kw_model=kw_model, keyword_mask_matrix=keyword_mask_matrix, keywordid2wordid=keywordid2wordid, CN_hopk_edge_index=CN_hopk_edge_index) # scheduler.step() cprint( "Config id: {0}, Epoch {1}: train loss: {2:.4f}, valid loss: {3:.4f}, valid precision: {4}, valid recall: {5}, valid MRR: {6}" .format(config_id, epoch + 1, train_loss, valid_loss, valid_precision, valid_recall, valid_MRR)) if scheduler is not None: cprint("Current learning rate: ", scheduler.get_last_lr()) epoch_train_losses.append(train_loss) epoch_valid_losses.append(valid_loss) epoch_valid_precisions.append(valid_precision) epoch_valid_recalls.append(valid_recall) epoch_valid_MRRs.append(valid_MRR) if save_model_path != "": if epoch == 0: for k, v in model.state_dict().items(): best_model_statedict[k] = v.cpu() else: if epoch_valid_recalls[-1][0] == max( [recall1 for recall1, _, _ in epoch_valid_recalls]): for k, v in model.state_dict().items(): best_model_statedict[k] = v.cpu() # early stopping if len(epoch_valid_recalls) >= 3 and epoch_valid_recalls[-1][ 0] < epoch_valid_recalls[-2][0] and epoch_valid_recalls[-2][ 0] < epoch_valid_recalls[-3][0]: break config.pop("seed") config.pop("config_id") metrics["config"] = config metrics["score"] = max([recall[0] for recall in epoch_valid_recalls]) metrics["epoch"] = np.argmax([recall[0] for recall in epoch_valid_recalls]).item() metrics["recall"] = epoch_valid_recalls[metrics["epoch"]] metrics["MRR"] = epoch_valid_MRRs[metrics["epoch"]] metrics["precision"] = epoch_valid_precisions[metrics["epoch"]] if save_model_path and seed == 1: cprint("Saving model to ", save_model_path) best_model_statedict["word2id"] = word2id best_model_statedict["model_kwargs"] = model_kwargs torch.save(best_model_statedict, save_model_path) return metrics
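# Sketch of the model-selection and early-stopping rule shared by the two main()
# functions above: keep a CPU copy of the state dict whenever recall@1 reaches a
# new maximum, and stop once recall@1 has dropped for two consecutive epochs.
# The recall tuples below are fabricated purely for illustration.
epoch_valid_recalls = []
best_state = None

for epoch, recalls in enumerate([(0.20, 0.4, 0.6), (0.25, 0.45, 0.65), (0.24, 0.44, 0.64), (0.22, 0.42, 0.62)]):
    epoch_valid_recalls.append(recalls)
    if epoch == 0 or recalls[0] == max(r1 for r1, _, _ in epoch_valid_recalls):
        best_state = {"epoch": epoch}  # stand-in for {k: v.cpu() for k, v in model.state_dict().items()}
    if (len(epoch_valid_recalls) >= 3
            and epoch_valid_recalls[-1][0] < epoch_valid_recalls[-2][0]
            and epoch_valid_recalls[-2][0] < epoch_valid_recalls[-3][0]):
        print("early stop at epoch", epoch)
        break

print("best epoch kept:", best_state)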
class NaverNerTrainer:
    def __init__(self, train_data_loader, model, learning_rate, warmup_step,
                 adam_ep, adam_beta1, adam_beta2, weight_decay):
        self.train_data_loader = train_data_loader
        self.device = self.get_device()
        self.model = self.set_model(model)
        self.optimizer = Adam(params=self.model.parameters(),
                              lr=learning_rate,
                              betas=(adam_beta1, adam_beta2),
                              eps=adam_ep,
                              weight_decay=weight_decay)
        # LambdaLR multiplies the base lr by the returned factor, so the lambda
        # returns a warmup factor in [0, 1] rather than an absolute learning rate.
        self.schedule = LambdaLR(
            optimizer=self.optimizer,
            lr_lambda=lambda lr_step: lr_step / warmup_step
            if lr_step < warmup_step else 1.0)
        self.loss_fn = nn.NLLLoss(ignore_index=0)

    def get_device(self):
        if torch.cuda.is_available():
            device = "cuda:0"
        else:
            device = "cpu"
        print("device :", device)
        return device

    def set_model(self, model):
        if self.device == "cuda:0":
            model = model.to(self.device)
            if torch.cuda.device_count() > 1:
                return nn.parallel.DistributedDataParallel(model)
        return model

    def train(self, epoch, trainable=True, train_verbose_step=100):
        for i, data in enumerate(self.train_data_loader):
            data = {k: v.to(self.device) for k, v in data.items()}
            lr = self.schedule.get_last_lr()[0]
            loss, result = self.model.forward(data["input"], data["label"])
            if trainable:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.schedule.step()
            if i == 0 or i % train_verbose_step == 0 or i == len(
                    self.train_data_loader) - 1:
                print({
                    "epoch": epoch,
                    "step": i,
                    "lr": lr,
                    "loss": loss.item()
                })
            data = {k: v.to("cpu") for k, v in data.items()}

    def save(self, epoch, path):
        torch.save(self.model.cpu(), path)
        self.model.to(self.device)
        print("Epoch :", epoch, "Save :", path)
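# Minimal standalone check of the warmup schedule used by NaverNerTrainer above (the
# learning_rate / warmup_step values here are assumptions): LambdaLR multiplies the base lr
# by the returned factor, so the lr ramps linearly from 0 to learning_rate and then holds.
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

learning_rate, warmup_step = 1e-4, 100          # illustrative assumptions
p = torch.nn.Parameter(torch.zeros(1))
opt = Adam([p], lr=learning_rate)
sched = LambdaLR(opt, lr_lambda=lambda s: s / warmup_step if s < warmup_step else 1.0)
for _ in range(2 * warmup_step):
    opt.step()
    sched.step()
assert abs(sched.get_last_lr()[0] - learning_rate) < 1e-12   # fully warmed up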
def main(): parser = argparse.ArgumentParser() parser.add_argument('--model', type=str, default='yv3a1_agg_dev') parser.add_argument('--train_set', type=str, default='HBMWR_mot_train') parser.add_argument('--val_set', type=str, default='Lab1_mot') parser.add_argument('--super_batchsize', type=int, default=32) parser.add_argument('--initial_imgsize', type=int, default=None) parser.add_argument('--optimizer', type=str, default='SGDMR') parser.add_argument('--optim_params', type=str, default='all') parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument('--warmup', type=int, default=1000) parser.add_argument('--checkpoint', type=str, default='rapid_H1MW1024_Mar11_4000_.pth') parser.add_argument('--print_interval', type=int, default=20) parser.add_argument('--eval_interval', type=int, default=200) parser.add_argument('--checkpoint_interval', type=int, default=2000) parser.add_argument('--demo_interval', type=int, default=100) parser.add_argument('--demo_images', type=str, default='fisheye') parser.add_argument('--debug_mode', type=str, default=None) args = parser.parse_args() assert torch.cuda.is_available() print('Initialing model...') model, global_cfg = name_to_model(args.model) # -------------------------- settings --------------------------- if args.debug_mode == 'overfit': raise NotImplementedError() print(f'Running debug mode: {args.debug_mode}...') # overfitting on one or a few images global_cfg['train.img_sizes'] = [640] global_cfg['train.initial_imgsize'] = 640 global_cfg['test.preprocessing'] = 'resize_pad_square' target_size = 640 global_cfg['train.data_augmentation'] = None enable_multiscale = False batch_size = 1 subdivision = 1 num_cpu = 0 warmup_iter = 40 elif args.debug_mode == 'local': print(f'Running debug mode: {args.debug_mode}...') # train on local laptop with a small resolution and batch size TRAIN_RESOLUTIONS = [384, 512] AUTO_BATCHSIZE = {'384': 4, '512': 2} initial_size = TRAIN_RESOLUTIONS[-1] global_cfg['train.initial_imgsize'] = initial_size batch_size = 2 seq_len = global_cfg['train.sequence_length'] super_batchsize = args.super_batchsize subdivision = int(np.ceil(super_batchsize / batch_size / seq_len)) # data augmentation setting enable_multiscale = True num_cpu = 0 warmup_iter = args.warmup # testing setting target_size = global_cfg.get('test.default_input_size', None) elif args.debug_mode == None: print(f'Debug mode disabled.') # normal training AUTO_BATCHSIZE = global_cfg['train.imgsize_to_batch_size'] TRAIN_RESOLUTIONS = global_cfg['train.img_sizes'] if args.initial_imgsize is not None: initial_size = args.initial_imgsize assert initial_size in TRAIN_RESOLUTIONS else: initial_size = TRAIN_RESOLUTIONS[-1] global_cfg['train.initial_imgsize'] = initial_size batch_size = AUTO_BATCHSIZE[str(initial_size)] seq_len = global_cfg['train.sequence_length'] super_batchsize = args.super_batchsize subdivision = int(np.ceil(super_batchsize / batch_size / seq_len)) # data augmentation setting enable_multiscale = True assert 'train.imgsize_to_batch_size' in global_cfg print( 'Auto-batchsize enabled. 
Automatically selecting the batch size.') num_cpu = 4 warmup_iter = args.warmup # testing setting target_size = global_cfg.get('test.default_input_size', None) else: raise Exception('Unknown debug mode') job_name = f'{args.model}_{args.train_set}_{args.lr}' # Prepare model pnum = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f'Number of trainable parameters of {args.model} =', pnum) model = model.cuda() model.train() # Training set and validation set setting print(f'Initializing training set {args.train_set}...') global_cfg['train.dataset_name'] = args.train_set dataset = get_trainingset(global_cfg) dataset.to_iterator(batch_size=batch_size, num_workers=num_cpu, pin_memory=True) print(f'Initializing validation set {args.val_set}...') eval_info, validation_func = get_valset(args.val_set) start_iter = -1 if args.checkpoint: print("Loading checkpoint...", args.checkpoint) weights_path = os.path.join(f'{PROJECT_ROOT}/weights', args.checkpoint) previous_state = torch.load(weights_path) if 'input' in global_cfg['model.agg.hidden_state_names']: for k in list(previous_state['model'].keys()): if 'netlist.0' in k: previous_state['model'].pop(k) try: model.load_state_dict(previous_state['model']) except: print('Cannot load weights. Trying to set strict=False...') model.load_state_dict(previous_state['model'], strict=False) print('Successfully loaded part of the weights.') start_iter = previous_state.get('iter', start_iter) print(f'Start from iteration: {start_iter}') print('Initializing tensorboard SummaryWriter...') if args.debug_mode: logger = SummaryWriter(f'{PROJECT_ROOT}logs/debug/{job_name}') else: logger = SummaryWriter(f'{PROJECT_ROOT}logs/{job_name}') print(f'Initializing optimizer with lr: {args.lr}') # set weight decay only on conv.weight params = [] if args.optim_params == 'all': for key, value in model.named_parameters(): decay = global_cfg[ 'train.sgd.weight_decay'] if 'conv' in key else 0.0 params += [{'params': value, 'weight_decay': decay}] elif args.optim_params == 'fix_backbone': for key, value in model.fpn.named_parameters(): decay = global_cfg[ 'train.sgd.weight_decay'] if 'conv' in key else 0.0 params += [{'params': value, 'weight_decay': decay}] for key, value in model.agg.named_parameters(): decay = global_cfg[ 'train.sgd.weight_decay'] if 'conv' in key else 0.0 params += [{'params': value, 'weight_decay': decay}] for key, value in model.rpn.named_parameters(): decay = global_cfg[ 'train.sgd.weight_decay'] if 'conv' in key else 0.0 params += [{'params': value, 'weight_decay': decay}] elif args.optim_params == 'agg_only': for key, value in model.agg.named_parameters(): decay = global_cfg[ 'train.sgd.weight_decay'] if 'conv' in key else 0.0 params += [{'params': value, 'weight_decay': decay}] else: raise NotImplementedError() pnum = sum(p['params'].numel() for p in params if p['params'].requires_grad) print(f'Number of training parameters =', pnum) # Initialize optimizer optimizer = optim.get_optimizer(name=args.optimizer, params=params, lr=args.lr, cfg=global_cfg) if args.checkpoint and args.optimizer in previous_state: try: optimizer.load_state_dict(previous_state[args.optimizer]) except: print( 'Failed loading optimizer state. Initialize optimizer from scratch.' 
) start_iter = -1 # Learning rate scheduler lr_schedule_func = lambda x: lr_warmup(x, warm_up=warmup_iter) from torch.optim.lr_scheduler import LambdaLR scheduler = LambdaLR(optimizer, lr_schedule_func, last_epoch=start_iter) print('Start training...') today = timer.today() start_time = timer.tic() for iter_i in range(start_iter, 1000000): # evaluation if iter_i > 0 and iter_i % args.eval_interval == 0: # if iter_i % args.eval_interval == 0: if args.debug_mode != 'overfit': model.eval() model.clear_hidden_state() with timer.contexttimer() as t0: model_eval = api.Detector(model_and_cfg=(model, global_cfg)) dts = model_eval.eval_predict_vod(eval_info, input_size=target_size, conf_thres=global_cfg.get( 'test.ap_conf_thres', 0.005)) eval_str, ap, ap50, ap75 = validation_func(dts) del model_eval s = f'\nCurrent time: [ {timer.now()} ], iteration: [ {iter_i} ]\n\n' s += eval_str + '\n\n' s += f'Validation elapsed time: [ {t0.time_str} ]' print(s) logger.add_text('Validation summary', s, iter_i) logger.add_scalar('Validation AP[IoU=0.5]', ap50, iter_i) logger.add_scalar('Validation AP[IoU=0.75]', ap75, iter_i) logger.add_scalar('Validation AP[IoU=0.5:0.95]', ap, iter_i) model.train() torch.cuda.reset_max_memory_allocated(0) seq_len = dataset.seq_len # subdivision loop optimizer.zero_grad() for _ in range(subdivision): seq_imgs, seq_labels, seq_flags, img_ids = dataset.get_next() assert len(seq_imgs) == len(seq_labels) == len(seq_flags) # visualize the clip for debugging if False: for b in range(batch_size): for _im, _lab in zip(seq_imgs, seq_labels): _im = image_ops.img_tensor_to_np( _im[b], model.input_format, 'BGR_uint8') _lab[b].draw_on_np(_im) cv2.imshow('', _im) cv2.waitKey(500) model.clear_hidden_state() for imgs, labels, is_start in zip(seq_imgs, seq_labels, seq_flags): imgs = imgs.cuda() loss = model(imgs, is_start, labels) assert not torch.isnan(loss) loss.backward() for p in model.parameters(): if p.grad is not None: p.grad.data.mul_(1.0 / subdivision / seq_len) optimizer.step() scheduler.step() # logging if iter_i % args.print_interval == 0: sec_used = timer.tic() - start_time time_used = timer.sec2str(sec_used) _ai = sec_used / (iter_i + 1 - start_iter) avg_iter = timer.sec2str(_ai) avg_img = _ai / batch_size / subdivision / seq_len avg_100img = timer.sec2str(avg_img * 100) avg_epoch = timer.sec2str(avg_img * 118287) print(f'\nTotal time: {time_used}, 100 imgs: {avg_100img}, ', f'iter: {avg_iter}, COCO epoch: {avg_epoch}') print( f'effective batch size = {batch_size} * {subdivision} * {seq_len}' ) max_cuda = torch.cuda.max_memory_allocated(0) / 1024 / 1024 / 1024 print(f'Max GPU memory usage: {max_cuda:.3f} GB') current_lr = scheduler.get_last_lr()[0] print(f'[Iteration {iter_i}] [learning rate {current_lr:.3g}]', f'[Total loss {loss:.2f}] [img size {dataset.img_size}]') print(model.loss_str) # random resizing if enable_multiscale and iter_i > 0 and (iter_i % 10 == 0): # # Randomly pick a input resolution imgsize = np.random.choice(TRAIN_RESOLUTIONS) # Set the image size in datasets batch_size = AUTO_BATCHSIZE[str(imgsize)] subdivision = int(np.ceil(super_batchsize / batch_size / seq_len)) dataset.img_size = imgsize dataset.to_iterator(batch_size=batch_size, num_workers=num_cpu, pin_memory=True) # save checkpoint if iter_i > 0 and (iter_i % args.checkpoint_interval == 0): state_dict = { 'iter': iter_i, 'model': model.state_dict(), args.optimizer: optimizer.state_dict(), } save_path = f'{PROJECT_ROOT}/weights/{job_name}_{today}_{iter_i}.pth' torch.save(state_dict, save_path) # save 
detection
        if iter_i > 0 and iter_i % args.demo_interval == 0:
            if args.debug_mode != 'overfit':
                model.eval()
            model_eval = api.Detector(model_and_cfg=(model, global_cfg))
            demo_images_dir = f'{PROJECT_ROOT}/images/{args.demo_images}'
            for imname in os.listdir(demo_images_dir):
                if not imname.endswith('.jpg'):
                    continue
                impath = os.path.join(demo_images_dir, imname)
                model_eval.model.clear_hidden_state()
                np_img = model_eval.detect_one(img_path=impath,
                                               return_img=True,
                                               conf_thres=0.3,
                                               input_size=target_size)
                if args.debug_mode is not None:
                    cv2_im = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
                    log_dir = f'{PROJECT_ROOT}/logs/{args.model}_debug/'
                    if not os.path.exists(log_dir):
                        os.mkdir(log_dir)
                    s = os.path.join(log_dir, f'{imname[:-4]}_iter{iter_i}.jpg')
                    cv2.imwrite(s, cv2_im)
                else:
                    if min(np_img.shape[:2]) > 512:
                        _h, _w = np_img.shape[:2]
                        _r = 512 / min(_h, _w)
                        np_img = cv2.resize(np_img, (int(_w * _r), int(_h * _r)))
                    logger.add_image(impath, np_img, iter_i, dataformats='HWC')
            model.train()
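# Hedged sketch of the resume pattern above (all values assumed; lr_warmup below is a
# hypothetical stand-in, not the project's real implementation): constructing LambdaLR with
# last_epoch=start_iter requires every param group to already contain an 'initial_lr' entry,
# which a previously saved optimizer state normally provides.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

warmup_iter, start_iter = 1000, 4000                          # assumed values

def lr_warmup(x, warm_up=1000):                               # hypothetical stand-in
    return min(max(x, 1) / warm_up, 1.0)

p = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([p], lr=0.0001)
for group in optimizer.param_groups:
    group.setdefault('initial_lr', group['lr'])               # normally restored from the checkpoint
scheduler = LambdaLR(optimizer, lambda x: lr_warmup(x, warm_up=warmup_iter),
                     last_epoch=start_iter)
print(scheduler.get_last_lr())                                # already past warmup -> [0.0001]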
class Trainer(object): def __init__(self, train_loader, eval_loader, model, optimizer, loss_fn, train_seq_length, train_sample_size=None, gradient_clipping=0.05): self.train_loader = train_loader self.eval_loader = eval_loader self.model = model.cuda() if torch.cuda.is_available() else model # self.model.weight_init() self.optimizer = optimizer self.loss_fn = loss_fn self.train_seq_length = train_seq_length self.train_sample_size = 9999999999 if train_sample_size is None else train_sample_size self.gradient_clipping = gradient_clipping scheduler_lambda = cos_decay_with_warmup(warmup=5, T=200, start_val=0.00001) self.scheduler = LambdaLR(self.optimizer, lr_lambda=scheduler_lambda) self.epoch = 0 self.patience_counter = 0 self.best_loss = None self.epoch_logger = EpochLogger() def train_epoch(self): self.model.train() nan_counter = 0 for step, (features_d, features_s, target, _) in enumerate(self.train_loader): if torch.cuda.is_available(): features_d = features_d.cuda(non_blocking=True) features_s = features_s.cuda(non_blocking=True) target = target.cuda(non_blocking=True) #if torch.isnan(features_d).any(): # raise ValueError( # 'NaN in dynamic features during training, training stopped.') #if torch.isnan(features_s).any(): # raise ValueError( # 'NaN in static features during training, training stopped.') #if torch.isnan(target).any(): # raise ValueError( # 'NaN in target during training, training stopped.') pred = self.model(features_d, features_s) loss = self.loss_fn( pred[:, self.train_loader.dataset.num_warmup_steps:], target[:, self.train_loader.dataset.num_warmup_steps:]) self.optimizer.zero_grad() loss.backward() if torch.isnan(loss): # This is a debugging feature, if NaNs occur, possible a bug or unstable # model. nan_counter += 1 if nan_counter > 9: raise ValueError( 'Training loss was NaN >5 times, training stopped.') warn( f'Training loss was NaN {nan_counter} time{"" if nan_counter==1 else "s"} ' 'in a row, stopping after >9.') continue torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping) self.optimizer.step() self.epoch_logger.log('loss', 'train', loss.item()) del loss if step > self.train_sample_size: break self.epoch += 1 self.scheduler.step() stats = self.epoch_logger.get_summary() return { 'epoch': self.epoch, 'lr': self.scheduler.get_last_lr()[0], **stats } @torch.no_grad() def eval_epoch(self): self.model.eval() for step, (features_d, features_s, target, _) in enumerate(self.eval_loader): if torch.cuda.is_available(): features_d = features_d.cuda(non_blocking=True) features_s = features_s.cuda(non_blocking=True) target = target.cuda(non_blocking=True) # if torch.isnan(features).any(): # raise ValueError( # 'NaN in features in evaluation, training stopped.') # if torch.isnan(targets).any(): # raise ValueError( # 'NaN in targets in evaluation, training stopped.') pred = self.model(features_d, features_s) loss = self.loss_fn( pred[:, self.eval_loader.dataset.num_warmup_steps:], target[:, self.eval_loader.dataset.num_warmup_steps:]) if torch.isnan(loss): raise ValueError('Eval loss is NaN, training stopped.') self.epoch_logger.log('loss', 'eval', loss.item()) if step > self.train_sample_size: break stats = self.epoch_logger.get_summary() perc_improved = self.early_stopping(stats['loss_eval']) return { **stats, 'patience_counter': self.patience_counter, 'perc_improved': perc_improved, 'best_loss': self.best_loss } @torch.no_grad() def predict(self, target_dir, use_training_set=False): self.model.eval() if use_training_set: print('\npredicting 
training set\n') else: print('\npredicting test set\n') print('Prediction saved to: ', target_dir) if use_training_set: data_loader = self.train_loader else: data_loader = self.eval_loader xr_var = data_loader.dataset.get_empty_xr() pred_array = np.zeros(xr_var.shape, dtype=np.float32) obs_array = np.zeros(xr_var.shape, dtype=np.float32) pred_array.fill(np.nan) obs_array.fill(np.nan) for step, (features_d, features_s, target, (lat, lon)) in enumerate(data_loader): print_progress( np.min(((step + 1) * data_loader.batch_size, len(data_loader.dataset))), len(data_loader.dataset), 'predicting') if torch.cuda.is_available(): features_d = features_d.cuda(non_blocking=True) features_s = features_s.cuda(non_blocking=True) target = target.cuda(non_blocking=True) pred = self.model(features_d, features_s) pred = pred[:, data_loader.dataset.num_warmup_steps:] target = target[:, data_loader.dataset.num_warmup_steps:] pred = self.unstandardize_target(pred) target = self.unstandardize_target(target) loss = self.loss_fn(pred, target) lat = lat.numpy() lon = lon.numpy() pred_array[:, lat, lon] = pred.cpu().numpy().T obs_array[:, lat, lon] = target.cpu().numpy().T self.epoch_logger.log('loss', 'test', loss.item()) print('\nWriting to file...') pred = xr.Dataset({ 'mod': xr.DataArray(pred_array, coords=[xr_var.time, xr_var.lat, xr_var.lon]), 'obs': xr.DataArray(obs_array, coords=[xr_var.time, xr_var.lat, xr_var.lon]) }) pred.obs.attrs = xr_var.attrs pred.obs.attrs = xr_var.attrs pred.attrs = { 'created': datetime.date.today().strftime('%b %d %Y'), 'contact': '[email protected], [email protected]', 'description': 'LSTM emulation of physical process model (Koirala et al. (2017))', 'var': xr_var.name, 'long_name': xr_var.attrs['long_name'] } pred_space_optim = pred.chunk({'lat': -1, 'lon': -1, 'time': 15}) pred_time_optim = pred.chunk({'lat': 15, 'lon': 15, 'time': -1}) if use_training_set: file_name_ending = '_trainset' else: file_name_ending = '' pred_space_optim.to_zarr( os.path.join(target_dir, f'pred_so{file_name_ending}.zarr')) pred_time_optim.to_zarr( os.path.join(target_dir, f'pred_to{file_name_ending}.zarr')) print('Done.') stats = self.epoch_logger.get_summary() return {**stats} def early_stopping(self, loss): if self.best_loss is not None: perc_improved = 100 * (1 - loss / self.best_loss) if perc_improved < 0.01: self.patience_counter += 1 else: self.patience_counter = 0 if loss < self.best_loss: self.best_loss = loss else: self.best_loss = loss self.perc_improved = perc_improved = 0 return perc_improved def save(self, checkpoint: str) -> None: """Saves the model at the provided checkpoint. Parameters ---------- checkpoint_dir Path to target checkpoint file. ¨ Returns ---------- checkpoint """ torch.save( { 'epoch': self.epoch, 'patience_counter': self.patience_counter, 'best_loss': self.best_loss, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict() }, checkpoint) return checkpoint def restore(self, checkpoint: str) -> None: """Restores the model from a provided checkpoint. Parameters ---------- filename Path to target checkpoint file. 
""" checkpoint = torch.load(checkpoint) self.model.load_state_dict(checkpoint['model_state_dict']) if torch.cuda.is_available(): self.model.to_device('cuda') self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) self.epoch = checkpoint['epoch'] self.patience_counter = checkpoint['patience_counter'] self.best_loss = checkpoint['best_loss'] def unstandardize_target(self, target): return self.train_loader.dataset.unstandardize( target, self.train_loader.dataset.target_var)
def run(n_epoch): train_set = Concat([ StateTransitionsDataset( hdf5_file="c_swm/data/blocks-{}-{}-{}_all.h5".format( OBJS, STACKS, 0), n_obj=OBJS + STACKS, remove_bg=REMOVE_BG, max_n_obj=8) for OBJS in [1, 2, 3, 4] ]) print("Training Examples: {}".format(len(train_set))) assert len(train_set) % TRAIN_BZ == 0 # assert len(test_set) % TEST_BZ == 0 train_loader = DataLoader(train_set, batch_size=TRAIN_BZ, shuffle=True) # test_loader = DataLoader(test_set, batch_size=TEST_BZ, shuffle=True) vae = FoSae().to(device) vae.load_state_dict( torch.load("fosae/model_{}/{}.pth".format(PREFIX, FOSAE_MODEL_NAME), map_location='cpu')) print("fosae/model_{}/{}.pth Loaded".format(PREFIX, FOSAE_MODEL_NAME)) vae.eval() new_data_set = get_new_dataset(train_loader, vae) del vae train_loader = DataLoader(dataset=new_data_set, batch_size=TRAIN_BZ, num_workers=4) test_loader = DataLoader(dataset=new_data_set, batch_size=TEST_BZ, num_workers=4) action_model = FoSae_Action().to(device) try: action_model.load_state_dict( torch.load("fosae/model_{}/{}.pth".format(PREFIX, ACTION_MODEL_NAME), map_location='cpu')) print("Action Model Loaded") except: print("Action Model Loaded Fail") pass optimizer = Adam(action_model.parameters(), lr=1e-3, betas=(0.9, 0.99)) # optimizer = SGD(action_model.parameters(), lr=1e-3) scheculer = LambdaLR(optimizer, lambda e: 1 if e < 100 else 0.1) best_loss = float('inf') for e in range(n_epoch): temp = np.maximum(TEMP_BEGIN * np.exp(-ANNEAL_RATE * e), TEMP_MIN) print("Epoch: {}, Temperature: {}, Lr: {}".format( e, temp, scheculer.get_last_lr())) sys.stdout.flush() train_loss = epoch_routine(train_loader, action_model, temp, optimizer) print('====> Epoch: {} Average train loss: {:.4f}'.format( e, train_loss)) test_loss = epoch_routine(test_loader, action_model, temp) print( '====> Epoch: {} Average test loss: {:.4f}, Best Test loss: {:.4f}' .format(e, test_loss, best_loss)) if test_loss < best_loss: print("Save Model") torch.save( action_model.state_dict(), "fosae/model_{}/{}.pth".format(PREFIX, ACTION_MODEL_NAME)) best_loss = test_loss scheculer.step()
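# Standalone illustration of the Gumbel-softmax temperature annealing used in the run()
# function above: exponential decay from TEMP_BEGIN, floored at TEMP_MIN. The constants here
# are assumed for illustration, not taken from the project's config.
import numpy as np

TEMP_BEGIN, TEMP_MIN, ANNEAL_RATE = 5.0, 0.7, 0.03
temps = [float(np.maximum(TEMP_BEGIN * np.exp(-ANNEAL_RATE * e), TEMP_MIN)) for e in range(200)]
# temps[0] == 5.0; the floor is reached once e > ln(TEMP_BEGIN / TEMP_MIN) / ANNEAL_RATE
# (roughly epoch 66 with these numbers), after which the temperature stays at TEMP_MIN.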
def train(rank, args, model, model_t, train_dataset_qa, test_dataset_qa, scale_tune): """ Train the model """ global train_count train_count += 1 world_size = 1 if rank < 0 else torch.distributed.get_world_size() if rank in [-1, 0]: printlog("Train model",train_count) printlog(model) per_gpu_train_batch_size = args.per_gpu_train_batch_size train_batch_size = per_gpu_train_batch_size * world_size gradient_accumulation_steps = args.total_train_batch_size // train_batch_size num_train_epochs = args.num_train_epochs if scale_tune: gradient_accumulation_steps = 1 num_train_epochs = 1 if rank < 0: #single process take all samples sampler = RandomSampler(train_dataset_qa) dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=train_batch_size, num_workers=4) else: #special sampler that divide samples beween processes sampler = torch.utils.data.distributed.DistributedSampler(train_dataset_qa, rank=rank) dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=per_gpu_train_batch_size) steps_total = int(len(dataloader) // gradient_accumulation_steps * num_train_epochs) # Prepare optimizer and schedule freeze_list = args.freeze_list.split(',') if args.freeze_list else [] named_params = [] for n, p in model.named_parameters(): if n.lower()!="none" and any(fn in n for fn in freeze_list): if rank in [-1, 0]: logger.warning("rank {} {} param is frozen and excluded from tune".format(rank,n)) continue named_params.append( (n, p) ) # split parameters to scale and the rest named_params_scale = [(n, p) for n, p in named_params if '.scale' in n] named_params_rest = [(n, p) for n, p in named_params if '.scale' not in n] if scale_tune: #keep only scale parameters named_params = named_params_scale named_params_rest = [] groups = [] if named_params_scale: groups.append({'params': [p for n, p in named_params_scale], 'lr': 0.01}) if named_params_rest: groups.append({'params': [p for n, p in named_params_rest], 'lr': args.learning_rate}) optimizer = AdamW( groups, eps=1e-08, lr=args.learning_rate, weight_decay=0) def lr_lambda(current_step): p = float(current_step) / float(steps_total) return 1 - p scheduler = LambdaLR(optimizer, lr_lambda) if rank in [-1, 0]: for n,p in named_params: printlog('param for tune',n) printlog("scale_tune", scale_tune ) printlog("dataset size", len(train_dataset_qa) ) printlog("epoches", num_train_epochs ) printlog("per_gpu_train_batch_size", per_gpu_train_batch_size ) printlog("n_gpu", args.n_gpu ) printlog("world_size", world_size ) printlog("gradient_accumulation_steps", gradient_accumulation_steps ) printlog("total train batch size", train_batch_size * gradient_accumulation_steps ) printlog("steps_total",steps_total ) global_step = 0 model.zero_grad() indicators = collections.defaultdict(list) softplus = torch.nn.Softplus() loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')]) if args.loss_cfg else dict() for epoch in range(math.ceil(num_train_epochs)): indicators = collections.defaultdict(list) model.train() set_output_hidden_states(rank, model, (model_t is not None)) utils.sync_models(rank, model) if model_t is not None: set_output_hidden_states(rank, model_t, True) model_t.train() if rank > -1: #set epoch to make different samples division betwen process for different epoches sampler.set_epoch(epoch) for step, batch in enumerate(dataloader): epoch_fp = epoch + step/len(dataloader) if epoch_fp > num_train_epochs: break epoch_fp = epoch + step/len(dataloader) losses = [] inputs = get_inputs(batch, args.device) targets = get_targets(batch, 
args.device) outputs = model(**inputs, **targets, output_hidden_states=(model_t is not None)) losses.append(outputs[0]) outputs = outputs[1:] if model_t is not None: with torch.no_grad(): outputs_t = model_t(**inputs, output_hidden_states=True) hidden_t = outputs_t[2] assert isinstance(hidden_t, (tuple,list)), "hidden states output is not detected right" assert len(hidden_t) == model_t.config.num_hidden_layers+1, "hidden states output is not detected right" if args.kd_weight>0: # Calculate knowladge distilation loss kd_losses = [] for logit_s,logit_t in zip(outputs[0:2],outputs_t[0:2]): T = 1 prob_t = torch.nn.functional.softmax(logit_t.detach() / T, dim=1) logprob_s = torch.nn.functional.log_softmax(logit_s / T, dim=1) kd_losses.append( -(logprob_s * prob_t).mean() * (T * T * prob_t.shape[1]) ) losses.append(args.kd_weight*sum(kd_losses)/len(kd_losses)) hidden_s = outputs[2] assert isinstance(hidden_s, (tuple,list)), "hidden states output is not detected right" assert len(hidden_s) == model.config.num_hidden_layers+1, "hidden states output is not detected right" def align_and_loss_outputs(out_s, out_t): if len(out_s) != len(out_t): #the student and teacher outputs are not aligned. try to find teacher output for each student output n_s, n_t = len(out_s), len(out_t) out_t = [out_t[(i*(n_t-1))//(n_s-1)] for i in range(n_s)] assert len(out_s) == len(out_t), "can not align number of outputs between student and teacher" assert all(s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)), "output shapes for student and teacher are not the same" return [(s - t.detach()).pow(2).mean() for s,t in zip(out_s, out_t)] sw_losses = align_and_loss_outputs(hidden_s,hidden_t) losses.extend([args.supervision_weight*l for l in sw_losses]) #average over batch losses = [l.mean() for l in losses] l = sum(losses)/len(losses) indicators['loss'].append(l.item()) indicators['ll'].append([lll.item() for lll in losses]) (l/gradient_accumulation_steps).backward() del l if (step + 1) % gradient_accumulation_steps == 0: global_step += 1 utils.sync_grads(rank, named_params, report_no_grad_params=(global_step==1)) torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() if global_step % 50 == 0: # Log metrics wall_time = epoch + step / len(dataloader) lrp = " ".join(['{:.2f}'.format(t) for t in np.log10(scheduler.get_last_lr())]) str_out = "{} ep {:.2f} lrp {}".format(train_count, epoch_fp, lrp) for k,v in indicators.items(): v = np.array(v) if len(v.shape)==1: v = v[:,None] if rank>-1: #sync indicators vt = torch.tensor(v).to(args.device) torch.distributed.all_reduce(vt, op=torch.distributed.ReduceOp.SUM) v = vt.cpu().numpy() / float(world_size) str_out += " {} {}".format(k," ".join(["{:.3f}".format(t) for t in v.mean(0)])) if 'time_last' in locals(): #estimate processing times dt_iter = (time.time() - time_last) / len(indicators['loss']) dt_ep = dt_iter * len(dataloader) str_out += " it {:.1f}s".format(dt_iter) str_out += " ep {:.1f}m".format(dt_ep / (60)) str_out += " eta {:.1f}h".format(dt_ep * (num_train_epochs - epoch_fp) / (60 * 60)) time_last = time.time() indicators = collections.defaultdict(list) if rank in [-1, 0]: logger.info(str_out) if rank in [-1, 0]: check_point_name = 'checkpoint-{:02}'.format(train_count) check_point_name = check_point_name + '-{:02}'.format(epoch + 1) model.eval() set_output_hidden_states(rank, model, False) result_s = evaluate(args, model, test_dataset_qa) for k,v in result_s.items(): 
logger.info("{} {} {}".format(check_point_name, k, result_s[k])) if rank>-1: torch.distributed.barrier()
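# Sketch of the layer-alignment rule inside align_and_loss_outputs above: when the student
# produces fewer hidden states than the teacher, student layer i is matched to teacher layer
# (i * (n_t - 1)) // (n_s - 1), so the first and last layers always coincide. The sizes
# below (7 student vs. 13 teacher outputs) are illustrative only.
n_s, n_t = 7, 13
mapping = [(i * (n_t - 1)) // (n_s - 1) for i in range(n_s)]
print(mapping)   # -> [0, 2, 4, 6, 8, 10, 12]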
def train(self, model, train_loader, val_loader=None, num_epochs=10, log_nth=0): """ Train a given model with the provided data. Inputs: - model: model object initialized from a torch.nn.Module - train_loader: train data in torch.utils.data.DataLoader - val_loader: val data in torch.utils.data.DataLoader - num_epochs: total number of training epochs - log_nth: log training accuracy and loss every nth iteration """ # filter out frcnn if this is added to the module parameters = [ param for name, param in model.named_parameters() if 'frcnn' not in name] optim = self.optim(parameters, **self.optim_args) if self.lr_scheduler_lambda: scheduler = LambdaLR(optim, lr_lambda=self.lr_scheduler_lambda) else: scheduler = None self._reset_histories() iter_per_epoch = len(train_loader) self.logger('START TRAIN.') for epoch in range(num_epochs): self.logger(f"[*] EPOCH: {epoch}") model.train() # TRAINING if scheduler is not None and epoch: scheduler.step() self.writer.add_scalar('TRAIN/LR', scheduler.get_last_lr(), epoch + 1) self.logger(f"[*] LR: {scheduler.get_lr()}") now = time.time() for i, batch in enumerate(train_loader, 1): optim.zero_grad() losses = model.sum_losses(batch) losses['total_loss'].backward() optim.step() for k,v in losses.items(): if k not in self._losses.keys(): self._losses[k] = [] self._losses[k].append(v.data.cpu().numpy()) if log_nth and (i == 1 or i % log_nth == 0): next_now = time.time() self.logger( f'[Iteration {i + epoch * iter_per_epoch}/{iter_per_epoch * num_epochs}] ' f'{log_nth / (next_now - now):.1f} it/s') now = next_now for k, v in self._losses.items(): last_log_nth_losses = self._losses[k][-log_nth:] train_loss = np.mean(last_log_nth_losses) self.logger(f'{k}: {train_loss:.3f}') self.writer.add_scalar(f'TRAIN/{k}', train_loss, i + epoch * iter_per_epoch) # VALIDATION if val_loader: self.logger("[VAL:]") # ensure determinisic and comparble evaluation # random_states = { # 'numpy': np.random.get_state(), # 'torch': torch.random.get_rng_state(), # 'random': random.getstate()} # np.random.set_state(self.val_random_states['numpy']) # torch.random.set_rng_state(self.val_random_states['torch']) # random.setstate(self.val_random_states['random']) model.eval() val_losses = {} for i, batch in enumerate(val_loader): losses = model.sum_losses(batch) for k, v in losses.items(): if k not in val_losses.keys(): val_losses[k] = [] val_losses[k].append(v.data.cpu().numpy()) # np.random.set_state(random_states['numpy']) # torch.random.set_rng_state(random_states['torch']) # random.setstate(random_states['random']) for k, val_loss in val_losses.items(): val_loss = np.mean(val_loss) if k not in self._val_losses.keys(): self._val_losses[k] = [] if (k == 'prec_at_k' and self._val_losses[k] and val_loss > np.max(self._val_losses[k])): self.snapshot(model, f'best_val_{k}') self._val_losses[k].append(val_loss) self.logger(f'{k}: {val_loss:.3f}') self.writer.add_scalar(f'VAL/{k}', val_loss, epoch + 1) #blobs_val = data_layer_val.forward() #tracks_val = model.val_predict(blobs_val) #im = plot_tracks(blobs_val, tracks_val) #self.writer.add_image('val_tracks', im, (epoch+1) * iter_per_epoch) self.snapshot(model, 'latest') self._reset_histories() self.writer.close() self.logger('FINISH.')
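# self.lr_scheduler_lambda is supplied by the surrounding training configuration; a typical
# choice (an assumption here, not necessarily this repository's setting) is an epoch-wise
# step decay, in which case LambdaLR makes get_last_lr() return [base_lr * factor(epoch)].
def step_decay(epoch, drop_every=30, factor=0.1):   # hypothetical example lambda
    return factor ** (epoch // drop_every)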
class Learner: """Class used for training a Pytorch neural network. The class is initialized with a CNN model, a loss function, an optimizer and train and validation datasets. The main methods are: fit(epochs) : train the network for the given number of epochs pred(xb) : apply model to a batch save_state() and load_state() : save and load learner state get_history() : return the train and validation losses, performance metrics and learning rates for each epoch Parameters ---------- model : torch.nn The neural network to be trained loss_func : callable The function used for calculating the loss. Should have signature loss_func(input, target, weight=None, epoch=None). `input` has shape (batch size, num classes, height, width) and target has shape (batch size, height, width). `weight` can be used for weighting the loss for each pixel (same shape as `target`) and `epoch` is the current training epoch. optimizer : torch.optim Optimizer used for updating the parameters. train_dl : torch.Dataset Dataloader used for training valid_dl : torch.Dataset Dataloader used for validation scheduler : torch.optim.lr_scheduler Scheduler used for updating the learning rate of the optimizer perf_funcs: dict Dict of functions to be used for measuring performance. Each key is a string containing the name of the performance metric and respective values are functions with signature f(input, target) containing the prediction of the model and the ground truth. Shapes of input and target are the same as in `loss_func` main_perf_func : string Performance metric used for checking if the model has improved. At the end of each epoch, if this metric is larger than any other previously recorded, the parameters of the model are saved checkpoint_file : string File to save the model when the perfomance have improved. Also used as default file for saving the model when function save_state is called scheduler_step_epoch : bool If True, the scheduler will call step() after every epoch. If False, step() will be called after every batch. callbacks : list of callable List of callback functions to call on the validation data after each epoch. Functions must have signature callback(val_batch, val_label_batch, model_output, epoch), where val_batch and val_label_batch are and image and label batch, model_output is the output of the model for the batch and epoch is the current epoch. device : torch.device Device used for training verbose : bool If True, prints information regarding the model performance after each epoch to the standard output . TODO: Model should be moved to `device` before constructing the optimizer. Only solution is to receive an optimizer class instead of an instance? 
""" def __init__(self, model, loss_func, optimizer, train_dl, valid_dl, scheduler=None, perf_funcs=None, main_perf_func='loss', checkpoint_file='./learner.tar', scheduler_step_epoch=True, callbacks=None, device=None, verbose=True): if scheduler is None: scheduler = LambdaLR(optimizer, lambda x: 1) # Fixed learning rate if perf_funcs is None: perf_funcs = {} if callbacks is None: callbacks = [] if device is None: if torch.cuda.is_available(): device = torch.device('cuda') else: device = torch.device('cpu') self.model = model self.loss_func = loss_func self.optimizer = optimizer self.train_dl = train_dl self.valid_dl = valid_dl self.scheduler = scheduler self.scheduler_init_state = scheduler self.perf_funcs = perf_funcs self.main_perf_func = main_perf_func self.checkpoint_file = checkpoint_file self.scheduler_step_epoch = scheduler_step_epoch self.callbacks = callbacks self.device = device self.verbose = verbose self.train_loss_history = [] self.valid_loss_history = [] perf_funcs_history = {} for k, v in perf_funcs.items(): perf_funcs_history[k] = [] self.perf_funcs_history = perf_funcs_history #nb, nc, h, w = self.get_output_shape() self.lr_history = [] self.epoch = 0 self.checkpoint = {} # Will store best model found self.best_score = None def fit(self, epochs, lr=None): """Train model for the given number of epochs. Each epoch consists in updating the weights for one pass in the training set and measuring loss and performance metrics for one pass in the validation set. Parameters ---------- epochs : int Number of epochs for training lr : float If a learning rate is passed, this new value is used for training indepently of the learning rate used when instantiating the class. Note that in this case learning rate schedulers are ignored. """ if lr is not None: # Fix the learning rate self.scheduler = LambdaLR(self.optimizer, lambda x: 1) # Fixed learning rate for pg in self.optimizer.param_groups: pg['lr'] = lr self.model.to(self.device) if self.verbose: self._print_epoch_info_header() for epoch in range(epochs): self._train_one_epoch() self._validate() if (self.scheduler is not None) and self.scheduler_step_epoch: self.lr_history.append(self.scheduler.get_last_lr()) self.scheduler.step() if self.verbose: self._print_epoch_info() self._check_if_better_score() self.epoch += 1 def _train_one_epoch(self): """Train model for one epoch.""" self.model.train() train_loss = 0. for item_collection in self.train_dl: #print(f'a: {torch.cuda.memory_allocated(device=self.device)/1024**3}') loss, _, _ = self._apply_model_to_batch(*item_collection) #print(f'b: {torch.cuda.memory_allocated(device=self.device)/1024**3}') self.optimizer.zero_grad() loss.backward() self.optimizer.step() #print(f'c: {torch.cuda.memory_allocated(device=self.device)/1024**3}') if (self.scheduler is not None) and (not self.scheduler_step_epoch): self.lr_history.append(self.scheduler.get_last_lr()) self.scheduler.step() with torch.no_grad(): train_loss += loss.item() self.train_loss_history.append(train_loss / len(self.train_dl)) def _validate(self): """Validate the model for one epoch.""" self.model.eval() valid_loss = 0. valid_perf = dict( zip(self.perf_funcs.keys(), [0.] 
* len(self.perf_funcs))) with torch.no_grad(): for item_collection in self.valid_dl: loss, predb, yb = self._apply_model_to_batch(*item_collection) valid_loss += loss.item() perfs = self._apply_perf_funcs(predb, yb) for key in perfs: valid_perf[key] += perfs[key] self.valid_loss_history.append(valid_loss / len(self.valid_dl)) for idx, (func_name, perf_func) in enumerate(self.perf_funcs.items()): self.perf_funcs_history[func_name].append( valid_perf[func_name] / len(self.valid_dl)) for cb in self.callbacks: cb.on_epoch_end(item_collection[0], item_collection[1], predb, self.epoch) def _apply_model_to_batch(self, xb, yb, wb=None): """Given an input and target batch, and optionaly a loss weights batch, apply the model to the data and calculates loss. Parameters ---------- xb : torch.Tensor Input data yb : torch.Tensor Target data wb : torch.Tensor Weights for each pixel Returns ------- loss : torch.float The calculated loss predb : torch.Tensor The predictions of the model. yb : torch.Tensor Target data converted to long and on the correct self.evice. """ device = self.device xb, yb = xb.to(device, torch.float32), yb.to(device, torch.long) predb = self.model(xb) if wb is None: loss = self.loss_func(predb, yb) else: wb = wb.to(device, torch.float32) loss = self.loss_func(predb, yb, wb) return loss, predb, yb def _print_epoch_info_header(self): """Print table header shown during training""" print_str = f'{"Epoch":<7}{"Train loss":>15}{"Valid loss":>15}' for func_name in self.perf_funcs: print_str += f'{func_name:>15}' print(print_str) def _print_epoch_info(self): """Print training and validation loss and perfomance metrics calculated for the current epoch.""" print_str = f'{self.epoch:5}{self.train_loss_history[-1]:17.3f}{self.valid_loss_history[-1]:15.3f}' for func_name in self.perf_funcs: perf_func_h = self.perf_funcs_history[func_name] print_str += f'{perf_func_h[-1]:15.3f}' print(print_str) def _apply_perf_funcs(self, predb, yb): """Apply each performance metric function to the data. Parameters ---------- predb : torch.Tensor The model predictions yb : torch.Tensor The target Returns ------- valid_perf : dict Dictionary containing the values calculated for each function. Each key is the name of a function in self.perf_funcs. """ valid_perf = {} for idx, (func_name, perf_func) in enumerate(self.perf_funcs.items()): valid_perf[func_name] = perf_func(predb, yb) return valid_perf def _check_if_better_score(self): """Check if the value of the main performance function has improved. If True, the model is saved in the file given by self.checkpoint_file.""" score_improved = False prev_score = self.best_score if self.main_perf_func == 'loss': score = self.valid_loss_history[-1] if (prev_score is None) or (score < prev_score): score_improved = True else: score = self.perf_funcs_history[self.main_perf_func][-1] if (prev_score is None) or (score > prev_score): score_improved = True if score_improved: #print(f'Score improved from {prev_score} to {score} checkpoint saved') self.best_score = score self.update_checkpoint() self.save_state(True) def get_state_dict(self): """Returns dictionary containing all relevant information about this class. 
Returns ------- state_dict : dict Dictionary containing relevant class attributes """ state_dict = { 'model_state': self.model.state_dict(), 'optimizer_state': self.optimizer.state_dict(), 'scheduler_state': self.scheduler.state_dict(), 'epoch': self.epoch, 'best_score': self.best_score, 'model': str(self.model), 'train_loss_history': self.train_loss_history, 'valid_loss_history': self.valid_loss_history, 'lr_history': self.lr_history, 'perf_funcs_history': self.perf_funcs_history } return state_dict def update_checkpoint(self): """Updates chekpoint of the model using current parameters.""" self.checkpoint = self.get_state_dict() def save_state(self, checkpoint=False, filename=None): """Saves all the relevant information about this class. Parameters ---------- checkpoint : bool If True, saves the parameters associated with the best model found during training, that is, the model providing the largest value of function perf_funcs[main_perf_func]. If False, saves the current parameters of the model. filename : string Filename to save the information. If None, it is given by self.checkpoint_file """ if filename is None: filename = self.checkpoint_file if checkpoint: torch.save(self.checkpoint, filename) else: torch.save(self.get_state_dict(), filename) def load_state(self, filename=None): """Loads all the relevant information about this class from a file. Attributes of the class are updated with the information read. If you just need the trained model for making new predictions, it is possible to do: checkpoint = torch.load(filename) model = checkpoint['model_state'] pred = model(img) instead of using this function. Parameters ---------- filename : string Filename to load the information. If None, it is given by self.checkpoint_file """ if filename is None: filename = self.checkpoint_file checkpoint = torch.load(filename) self.model.load_state_dict(checkpoint['model_state']) self.optimizer.load_state_dict(checkpoint['optimizer_state']) self.scheduler.load_state_dict(checkpoint['scheduler_state']) self.epoch = checkpoint['epoch'] self.train_loss_history = checkpoint['train_loss_history'] self.valid_loss_history = checkpoint['valid_loss_history'] self.lr_history = checkpoint['lr_history'] self.perf_funcs_history = checkpoint['perf_funcs_history'] self.best_score = checkpoint['best_score'] self.checkpoint = checkpoint def save_model_dict(self, filename='model.pt'): """Save the parameters of the model.""" torch.save(self.model.state_dict(), filename) def load_model_dict(self, filename='model.pt'): """Load the parameters of the model.""" self.model.load_state_dict(torch.load(filename)) def save_model(self, filename='model.pickle'): """Save the model as a pickle file.""" torch.save(filename) def load_model(self, filename='model.pickle'): """Load a model from a pickle file.""" self.model = torch.load(filename) def save_history(self, filename, sep=';'): """Save the loss and performance metrics history to a file.""" train_loss_history = self.train_loss_history valid_loss_history = self.valid_loss_history perf_funcs_history = self.perf_funcs_history header = f'Epoch{sep}Train loss{sep}Valid loss' for func_name in perf_funcs_history: header += f'{sep}{func_name}' with open(filename, 'w') as fd: fd.write(header + '\n') for epoch in range(len(train_loss_history)): line_str = f'{epoch+1}{sep}{self.train_loss_history[epoch]:.5f}{sep}{valid_loss_history[epoch]:.5f}' for func_name, perf_func_h in perf_funcs_history.items(): line_str += f'{sep}{perf_func_h[epoch]:.5f}' fd.write(line_str + '\n') def pred(self, 
xb, yb=None, return_classes=False): """Apply model to a batch, model parameters are not updated. If `yb` is provided, also returns the performance metrics of the prediction for the given target (functions in `self.perf_funcs`). The possible returned values of this function are: if yb is None: if return_classes: Predicted classes else: Output of the model else: if return_classes: (Predicted classes, performance values) else: (Output of the model, performance values) Parameters ---------- xb : torch.Tensor Input of the model. Must have shape (batch size, channels, height, width) yb : torch.Tensor The target. Must have shape (batch size, height, width) return_classes : bool If True, returns classes instead of probabilities. Also returns performance metrics for the prediction Returns ------- predb : torch.Tensor The predicted class probabilities. Only returned if `return_classes` is False bin_predb : torch.Tensor The predicted segmentation. Returned in place of `predb` if `return_classes` is True predb_perf : float Performance metrics calculated for functions self.perf_funcs. Only returned if `yb` is not None # TODO: Implement TTA augmentation """ self.model.to(self.device) self.model.eval() with torch.no_grad(): xb = xb.to(self.device, torch.float32) predb = self.model(xb).to('cpu') if return_classes: classes_predb = torch.argmax(predb, dim=1).to('cpu', torch.uint8) if yb is None: if return_classes: return classes_predb else: return predb else: predb_perf = self._apply_perf_funcs(predb, yb) if return_classes: return classes_predb, predb_perf else: return predb, predb_perf def test(self, test_dl): """Measure the performance of the model for a dataset. Calculated performance metrics are given by the functions in self.perf_funcs. Parameters ---------- test_dl : torch.Dataset The input dataset. Usually a dataset used for the testing phase. Returns ------- test_perf : dict Dictionary of calculated performance metrics with the same keys as self.perf_funcs """ test_perf = dict( zip(self.perf_funcs.keys(), [0.] * len(self.perf_funcs))) with torch.no_grad(): for xb, yb in test_dl: _, pred_perf = self.pred(xb, yb) for idx, (k, v) in enumerate(pred_perf.items()): test_perf[k] += v test_perf = {k: v / len(test_dl) for k, v in test_perf.items()} return test_perf def get_output_shape(self): """Calculate the output shape of the model from one of the items in the dataset. Returns ------- tuple Shape of the output from the model """ xb, *_ = next(iter(self.train_dl.dataset[0])) self.model.to(self.device) xb = xb.to(self.device) with torch.no_grad(): pred = self.model(xb) self.model.to('cpu') return pred.shape def get_history(self): """Return the recorded history for some parameters and evaluations of this learner. Returned values are: train_loss : the training loss valid_loss : the validation loss perf : performance metrics calculated for functions stored in self.perf_funcs lr : learning rate Returns ------- history : dict Dictionary keyed by the name of the property. """ history = { 'train_loss': self.train_loss_history, 'valid_loss': self.valid_loss_history, 'perf': self.perf_funcs_history, 'lr': self.lr_history } return history def reset_history(self): """Reset the history stats for this learner.""" self.train_loss_history = [] self.valid_loss_history = [] self.lr_history = [] perf_funcs_history = {} for k, v in self.perf_funcs.items(): perf_funcs_history[k] = [] self.perf_funcs_history = perf_funcs_history def reset_training(self, optimizer, scheduler): """Reset parameters and history for this learner. 
        Notice that this does not reset the optimizer and scheduler parameters!
        You can use functions set_optimizer() and set_scheduler() or
        reset_scheduler() for that."""

        self.model.reset_parameters()  # Will probably not work for fastai model
        self.reset_history()

    def reset_scheduler(self):
        """Reset the scheduler to its initial state."""
        self.scheduler.load_state_dict(self.scheduler_init_state)

    def set_optimizer(self, optimizer, *args, **kwargs):
        """Set or update the optimizer.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            New optimizer. Note that a class should be passed, not an instance.
            If you want to set a new optimizer instance, do
            my_learner.optimizer = optimizer
        """
        self.optimizer = optimizer(self.model.parameters(), *args, **kwargs)

    def set_scheduler(self, scheduler, *args, **kwargs):
        """Set or update the learning rate scheduler.

        Note that a class should be passed, not an instance. If you want to set
        a new scheduler instance, do my_learner.scheduler = scheduler

        Parameters
        ----------
        scheduler : torch.optim.lr_scheduler._LRScheduler
            New scheduler
        """
        self.scheduler = scheduler(self.optimizer, *args, **kwargs)
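# The Learner falls back to LambdaLR(optimizer, lambda x: 1) when no scheduler is supplied;
# a constant multiplier of 1 keeps the learning rate fixed while still letting the class call
# scheduler.step() and get_last_lr() unconditionally. Minimal check with assumed values:
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

opt = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)
fixed = LambdaLR(opt, lambda x: 1)
for _ in range(10):
    opt.step()
    fixed.step()
assert fixed.get_last_lr() == [0.01]   # unchanged after any number of steps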
def train(train_data, val_data, model, optimizer, logger, saver, num_epochs, batch_size, grad_clip): """Train SAKT model. Arguments: train_data (list of tuples of torch Tensor) val_data (list of tuples of torch Tensor) model (torch Module) optimizer (torch optimizer) logger: wrapper for TensorboardX logger saver: wrapper for torch saving num_epochs (int): number of epochs to train for batch_size (int) grad_clip (float): max norm of the gradients """ criterion = nn.BCEWithLogitsLoss() metrics = Metrics() step = 0 if optimizer == 'adam': optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) scheduler = None elif optimizer == 'noam': optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) def noam(step: int): step = max(1, step) warmup_steps = 2000 scale = warmup_steps**0.5 * min(step**(-0.5), step * warmup_steps**(-1.5)) return scale scheduler = LambdaLR(optimizer=optimizer, lr_lambda=noam) else: raise NotImplementedError val_batches = prepare_batches(val_data, batch_size, randomize=False) for epoch in range(num_epochs): train_batches = prepare_batches(train_data, batch_size) # Training for item_inputs, skill_inputs, label_inputs, item_ids, skill_ids, labels in train_batches: item_inputs = item_inputs.cuda() skill_inputs = skill_inputs.cuda() label_inputs = label_inputs.cuda() item_ids = item_ids.cuda() skill_ids = skill_ids.cuda() preds = model(item_inputs, skill_inputs, label_inputs, item_ids, skill_ids) loss = compute_loss(preds, labels.cuda(), criterion) preds = torch.sigmoid(preds).detach().cpu() train_auc = compute_auc(preds, labels) model.zero_grad() loss.backward() clip_grad_norm_(model.parameters(), grad_clip) optimizer.step() step += 1 metrics.store({'loss/train': loss.item()}) metrics.store({'auc/train': train_auc}) if scheduler is not None: metrics.store({'lr': scheduler.get_last_lr()[0]}) scheduler.step() # Logging if step % 20 == 0: logger.log_scalars(metrics.average(), step) # Validation model.eval() for item_inputs, skill_inputs, label_inputs, item_ids, skill_ids, labels in val_batches: item_inputs = item_inputs.cuda() skill_inputs = skill_inputs.cuda() label_inputs = label_inputs.cuda() item_ids = item_ids.cuda() skill_ids = skill_ids.cuda() with torch.no_grad(): preds = model(item_inputs, skill_inputs, label_inputs, item_ids, skill_ids) preds = torch.sigmoid(preds).cpu() val_auc = compute_auc(preds, labels) metrics.store({'auc/val': val_auc}) model.train() # Save model average_metrics = metrics.average() average_metrics['epoch'] = epoch logger.log_scalars(average_metrics, step) stop = saver.save(average_metrics['auc/val'], model, epoch, average_metrics) if stop: break
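# Standalone check (illustrative) of the Noam factor defined above: LambdaLR multiplies the
# base Adam lr by this factor, which grows roughly linearly up to warmup_steps, peaks at 1.0
# there, and then decays as step ** -0.5.
warmup_steps = 2000

def noam(step: int) -> float:
    step = max(1, step)
    return warmup_steps ** 0.5 * min(step ** (-0.5), step * warmup_steps ** (-1.5))

print(noam(1), noam(1000), noam(2000), noam(8000))
# -> roughly 0.0005, 0.5, 1.0, 0.5   (8000 = 4 * warmup_steps, so the factor is 1 / sqrt(4))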
def main(): parser = argparse.ArgumentParser() parser.add_argument('--model', type=str, default='yolov3_80') parser.add_argument('--train_set', type=str, default='wheat1') parser.add_argument('--val_set', type=str, default='wheat1') parser.add_argument('--super_batchsize', type=int, default=32) parser.add_argument('--initial_imgsize', type=int, default=None) parser.add_argument('--optimizer', type=str, default='SGDMN') parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument('--warmup', type=int, default=1000) parser.add_argument('--checkpoint', type=str, default='') parser.add_argument('--print_interval', type=int, default=20) parser.add_argument('--eval_interval', type=int, default=100) parser.add_argument('--checkpoint_interval', type=int, default=2000) parser.add_argument('--demo_interval', type=int, default=20) parser.add_argument('--demo_images', type=str, default='wheat1') parser.add_argument('--debug_mode', type=str, default='overfit') args = parser.parse_args() assert torch.cuda.is_available() print('Initialing model...') model, global_cfg = name_to_model(args.model) # -------------------------- settings --------------------------- ap_conf_thres = global_cfg.get('test.ap_conf_thres', 0.005) if args.debug_mode == 'overfit': print(f'Running debug mode: {args.debug_mode}...') global_cfg['train.img_sizes'] = [640] global_cfg['train.initial_imgsize'] = 640 global_cfg['test.preprocessing'] = 'resize_pad_square' target_size = 640 global_cfg['train.data_augmentation'] = None enable_multiscale = False batch_size = 1 accumulate = 1 num_cpu = 0 warmup_iter = 40 elif args.debug_mode == 'local': print(f'Running debug mode: {args.debug_mode}...') # train on local laptop with a small resolution and batch size TRAIN_RESOLUTIONS = [384, 512] AUTO_BATCHSIZE = {'384': 4, '512': 2} initial_size = TRAIN_RESOLUTIONS[-1] global_cfg['train.initial_imgsize'] = initial_size batch_size = 2 super_batchsize = 8 accumulate = int(np.ceil(super_batchsize / batch_size)) # data augmentation setting enable_multiscale = True num_cpu = 0 warmup_iter = args.warmup # testing setting target_size = global_cfg.get('test.default_input_size', None) elif args.debug_mode == None: # training setting TRAIN_RESOLUTIONS = global_cfg['train.img_sizes'] AUTO_BATCHSIZE = global_cfg['train.imgsize_to_batch_size'] if args.initial_imgsize is not None: initial_size = args.initial_imgsize assert initial_size in TRAIN_RESOLUTIONS else: initial_size = TRAIN_RESOLUTIONS[-1] global_cfg['train.initial_imgsize'] = initial_size batch_size = AUTO_BATCHSIZE[str(initial_size)] super_batchsize = args.super_batchsize accumulate = int(np.ceil(super_batchsize / batch_size)) # data augmentation setting enable_multiscale = True assert 'train.imgsize_to_batch_size' in global_cfg print( 'Auto-batchsize enabled. 
Automatically selecting the batch size.') # optimizer setting num_cpu = 4 if global_cfg[ 'train.hard_example_mining'] != 'probability' else 0 warmup_iter = args.warmup # testing setting target_size = global_cfg.get('test.default_input_size', None) else: raise Exception('Unknown debug mode') job_name = f'{args.model}_{args.train_set}_{args.lr}' # Prepare model pnum = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f'Number of trainable parameters in {args.model}:', pnum) model = model.cuda() model.train() # Training set and validation set setting print(f'Initializing training set {args.train_set}...') global_cfg['train.dataset_name'] = args.train_set dataset = get_trainingset(global_cfg) dataset.to_iterator(batch_size=batch_size, shuffle=True, num_workers=num_cpu, pin_memory=True) print(f'Initializing validation set {args.val_set}...') eval_info, validation_func = get_valset(args.val_set) # validation function for hard example mining eval_func_ = eval_info['val_func'] if args.checkpoint: print("Loading checkpoint...", args.checkpoint) weights_path = f'{PROJECT_ROOT}/weights/{args.checkpoint}' previous_state = torch.load(weights_path) try: model.load_state_dict(previous_state['model']) except: print('Cannot load weights. Trying to set strict=False...') model.load_state_dict(previous_state['model'], strict=False) print('Successfully loaded part of the weights.') print('Initializing tensorboard SummaryWriter...') if args.debug_mode: logger = SummaryWriter(f'{PROJECT_ROOT}/logs/debug/{job_name}') else: logger = SummaryWriter(f'{PROJECT_ROOT}/logs/{job_name}') print(f'Initializing optimizer with lr: {args.lr}') # set weight decay only on conv.weight pg0, pg1, pg2 = [], [], [] # optimizer parameter groups for k, v in model.named_parameters(): if v.requires_grad: assert '.conv' in k or '.bn' in k if '.bias' in k: pg2.append(v) # biases elif '.conv' in k and '.weight' in k: pg1.append(v) # apply weight decay else: pg0.append(v) # all else optimizer = optim.get_optimizer(name=args.optimizer, params=pg0, lr=args.lr, global_cfg=global_cfg) decay = global_cfg['train.sgd.weight_decay'] optimizer.add_param_group({'params': pg1, 'weight_decay': decay}) optimizer.add_param_group({'params': pg2}) print( f'Optimizer groups: {len(pg1)} conv.weight, {len(pg2)} .bias, {len(pg0)} other' ) del pg0, pg1, pg2 start_iter = -1 if args.checkpoint and args.optimizer in previous_state: optimizer.load_state_dict(previous_state[args.optimizer]) start_iter = previous_state.get('iter', -2) + 1 print(f'Start from iteration: {start_iter}') # Learning rate scheduler lr_schedule_func = lambda x: optim.lr_warmup(x, warm_up=warmup_iter) from torch.optim.lr_scheduler import LambdaLR scheduler = LambdaLR(optimizer, lr_schedule_func, last_epoch=start_iter) print('Start training...') today = timer.today() start_time = timer.tic() for iter_i in range(start_iter, 1000000): # evaluation if iter_i > 0 and iter_i % args.eval_interval == 0: if not args.debug_mode: model.eval() with timer.contexttimer() as t0: model_eval = api.Detector(model_and_cfg=(model, global_cfg)) dts = model_eval.evaluation_predict( eval_info, input_size=target_size, conf_thres=ap_conf_thres, catIdx2id=dataset.catIdx2id) eval_str, ap, ap50, ap75 = validation_func(dts) del model_eval s = f'\nCurrent time: [ {timer.now()} ], iteration: [ {iter_i} ]\n\n' s += eval_str + '\n\n' s += f'Validation elapsed time: [ {t0.time_str} ]' print(s) logger.add_text('Validation summary', s, iter_i) logger.add_scalar('Validation AP[IoU=0.5]', ap50, iter_i) 
logger.add_scalar('Validation AP[IoU=0.75]', ap75, iter_i) logger.add_scalar('Validation AP[IoU=0.5:0.95]', ap, iter_i) model.train() torch.cuda.reset_max_memory_allocated(0) # accumulate loop optimizer.zero_grad() for _ in range(accumulate): batch = dataset.get_next() imgs, labels = batch['images'], batch['labels'] # for t_im, lbl in zip(imgs, labels): # np_im = image_ops.tensor_to_np(t_im, model.input_format, 'RGB_uint8') # lbl.draw_on_np(np_im, class_map='COCO', imshow=True) imgs = imgs.cuda() # try: dts, loss = model(imgs, labels) assert not torch.isnan(loss) loss.backward() # except RuntimeError as e: # if 'CUDA out of memory' in str(e): # print(f'CUDA out of memory at imgsize={dataset.img_size},', # f'batchsize={batch_size}') # raise e # if 'CUDA out of memory' in str(e): # print(f'CUDA out of memory at imgsize={dataset.img_size},', # f'batchsize={batch_size}') # print('Trying to reduce the batchsize at that image size...') # AUTO_BATCHSIZE[str(dataset.img_size)] -= 1 # dataset.to_iterator(batch_size=batch_size-1, shuffle=True, # num_workers=num_cpu, pin_memory=True) # else: # raise e # assert AUTO_BATCHSIZE[str(dataset.img_size)] == batch_size if global_cfg['train.hard_example_mining'] in {'probability'}: # calculate AP for each image idxs, img_ids, anns = batch['indices'], batch[ 'image_ids'], batch['anns'] for d, _idx, _id, g in zip(dts, idxs, img_ids, anns): d: ImageObjects d = d.post_process(conf_thres=ap_conf_thres, nms_thres=global_cfg['test.nms_thres']) d = d.to_json(img_id=_id, eval_type=eval_info['eval_type']) _, ap, ap50, ap75 = eval_func_(d, g, str_print=False) dataset.update_ap(_idx, ap) for p in model.parameters(): if p.grad is not None: p.grad.data.mul_(1.0 / accumulate) optimizer.step() scheduler.step() # logging if iter_i % args.print_interval == 0: sec_used = timer.tic() - start_time time_used = timer.sec2str(sec_used) _ai = sec_used / (iter_i + 1 - start_iter) avg_iter = timer.sec2str(_ai) avg_100img = timer.sec2str(_ai / batch_size / accumulate * 100) avg_epoch = timer.sec2str(_ai / batch_size / accumulate * 118287) print(f'\nTotal time: {time_used}, 100 imgs: {avg_100img}, ', f'iter: {avg_iter}, COCO epoch: {avg_epoch}') print(f'effective batch size = {batch_size} * {accumulate}') max_cuda = torch.cuda.max_memory_allocated(0) / 1024 / 1024 / 1024 print(f'Max GPU memory usage: {max_cuda:.3f} GB') current_lr = scheduler.get_last_lr()[0] print(f'[Iteration {iter_i}] [learning rate {current_lr:.3g}]', f'[Total loss {loss:.2f}] [img size {dataset.img_size}]') print(model.loss_str) # random resizing if enable_multiscale and iter_i > 0 and (iter_i % 10 == 0): # # Randomly pick a input resolution imgsize = np.random.choice(TRAIN_RESOLUTIONS) # Set the image size in datasets batch_size = AUTO_BATCHSIZE[str(imgsize)] accumulate = int(np.ceil(super_batchsize / batch_size)) dataset.img_size = imgsize dataset.to_iterator(batch_size=batch_size, shuffle=True, num_workers=num_cpu, pin_memory=True) # save checkpoint if iter_i > 0 and (iter_i % args.checkpoint_interval == 0): state_dict = { 'iter': iter_i, 'model': model.state_dict(), args.optimizer: optimizer.state_dict(), 'dataset': dataset.hem_state } save_path = f'{PROJECT_ROOT}/weights/{job_name}_{today}_{iter_i}.pth' torch.save(state_dict, save_path) # save detection if iter_i > 0 and iter_i % args.demo_interval == 0: if not args.debug_mode: model.eval() model_eval = api.Detector(model_and_cfg=(model, global_cfg)) demo_images_dir = f'{PROJECT_ROOT}/images/{args.demo_images}' for imname in os.listdir(demo_images_dir): # if 
not imname.endswith('.jpg'): continue impath = os.path.join(demo_images_dir, imname) np_img = model_eval.detect_one(img_path=impath, return_img=True, conf_thres=0.3, input_size=target_size) if args.debug_mode: cv2_im = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR) log_dir = f'{PROJECT_ROOT}/logs/{args.model}_debug/' if not os.path.exists(log_dir): os.mkdir(log_dir) s = os.path.join(log_dir, f'{imname[:-4]}_iter{iter_i}.jpg') cv2.imwrite(s, cv2_im) else: if min(np_img.shape[:2]) > 512: _h, _w = np_img.shape[:2] _r = 512 / min(_h, _w) np_img = cv2.resize(np_img, (int(_w * _r), int(_h * _r))) logger.add_image(impath, np_img, iter_i, dataformats='HWC') model.train()
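# The warmup schedule above is built from the project's optim.lr_warmup(), whose body is
# not shown in this file.  The sketch below assumes a common YOLO-style choice, a quadratic
# burn-in that ramps the LambdaLR multiplier from ~0 to 1 over `warm_up` iterations and then
# holds it constant; names and the burn-in shape are illustrative, not the project's API.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR


def lr_warmup_sketch(step, warm_up=1000):
    # quadratic burn-in, then a constant multiplier of 1.0
    return min(1.0, ((step + 1) / warm_up) ** 2)


params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = SGD(params, lr=0.0001)
scheduler = LambdaLR(optimizer, lambda x: lr_warmup_sketch(x, warm_up=40))
for it in range(80):
    optimizer.step()
    scheduler.step()
    if it % 20 == 0:
        print(it, scheduler.get_last_lr()[0])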
def run_COPT(game, num_iter=5000, lr=0.5, seed=1234, biased=False, shuffling=False,
             lr_schedule=None, hamiltonian_coeff=10, **kwargs):
    config = Config(dict(mode="consensus_opt", num_iter=num_iter, lr=lr, seed=seed,
                         hamiltonian_coeff=hamiltonian_coeff, shuffling=shuffling))
    torch.manual_seed(seed)
    game.reset()
    sgd = optim.SGD(game.parameters(), lr=lr)
    if lr_schedule is not None:
        lr_schedule = SchedulerLR(lr_schedule)
        scheduler = LambdaLR(sgd, lr_schedule)
    else:
        scheduler = LambdaLR(sgd, lambda k: 1.)
    logger = defaultdict(list)
    path = None
    if kwargs["output"] is not None:
        path = os.path.join(kwargs["output"], config.name, str(seed))
        config["path"] = path
        if not os.path.exists(path):
            os.makedirs(os.path.join(path, "results"))
        config["name"] = config.name
        with open(os.path.join(path, "config.json"), "w") as f:
            json.dump(config, f, default=lambda x: "non-serializable")
    if shuffling:
        game.shuffle()
    n_samples = 0
    start_time = time.time()
    for i in tqdm(range(num_iter)):
        index1 = game.sample()
        index2 = game.sample()
        if biased is True:
            grad1 = game.compute_grad(index1)
            grad2 = grad1
            hamiltonian = compute_hamiltonian(grad1)
            n_samples += 1
        elif biased == "copt":
            grad1 = game.compute_grad(torch.cat([index1, index2]))
            grad2 = grad1
            hamiltonian = compute_hamiltonian(grad1)
            n_samples += 2
        elif biased is False:
            grad1 = game.compute_grad(index1)
            grad2 = game.compute_grad(index2)
            hamiltonian = compute_hamiltonian(grad1)
            n_samples += 2
        else:
            raise ValueError()
        grad_H = autograd.grad(hamiltonian, game.parameters())
        for p, g1, g2, gH in zip(game.parameters(), grad1, grad2, grad_H):
            p.grad = 0.5 * (g1 + g2) + hamiltonian_coeff * gH
        sgd.step()
        scheduler.step()
        metrics = game.compute_metrics()
        for key, value in metrics.items():
            logger[key].append(value)
        logger["lr"].append(scheduler.get_last_lr())
        logger["num_samples"].append(n_samples)
        logger["time"].append(time.time() - start_time)
        # Dump intermediate results only when an output directory was given;
        # otherwise `path` is undefined and this line raised a NameError.
        if path is not None and i % 10000 == 0:
            with open(os.path.join(path, "results.json"), "w") as f:
                json.dump(logger, f)
    return logger, config
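# compute_hamiltonian() and the `game` API are defined elsewhere; below is a minimal standalone
# sketch of the consensus-optimisation step performed in the loop above, assuming the usual
# convention H = 0.5 * sum_i ||g_i||^2 over the players' gradients.  The bilinear toy game,
# step size, and coefficient are illustrative only.
import torch
from torch import autograd


def compute_hamiltonian_sketch(grads):
    # half the squared norm of the simultaneous-gradient field
    return 0.5 * sum(g.pow(2).sum() for g in grads)


x = torch.nn.Parameter(torch.tensor(1.0))
y = torch.nn.Parameter(torch.tensor(1.0))
step_size, coeff = 0.1, 1.0
for _ in range(100):
    value = x * y                                     # bilinear game: min_x max_y x*y
    g_x, g_y = autograd.grad(value, (x, y), create_graph=True)
    field = (g_x, -g_y)                               # descent direction for x, ascent for y
    grad_H = autograd.grad(compute_hamiltonian_sketch(field), (x, y))
    with torch.no_grad():
        for p, g, gh in zip((x, y), field, grad_H):
            p -= step_size * (g + coeff * gh)         # gradient + Hamiltonian correction
print(round(x.item(), 5), round(y.item(), 5))         # both iterates shrink towards (0, 0)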
def train(): train_data = data_rgbplane.SUNRGBD(transform=transforms.Compose([ data_rgbplane.scaleNorm(), data_rgbplane.RandomScale((1.0, 1.2)), data_rgbplane.RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)), data_rgbplane.RandomCrop(image_h, image_w), data_rgbplane.RandomFlip(), data_rgbplane.ToTensor(), data_rgbplane.Normalize()]), phase_train=True, data_dir=args.data_dir) train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=False) num_train = len(train_data) if args.last_ckpt: model = model_rgbplane.model(pretrained=False) else: model = model_rgbplane.model(pretrained=True) if torch.cuda.device_count() > 1: print("Let's use", torch.cuda.device_count(), "GPUs!") model = nn.DataParallel(model) CEL_weighted = utils.CrossEntropyLoss2d() model.train() model.to(device) CEL_weighted.to(device) optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) global_step = 0 if args.last_ckpt: global_step, args.start_epoch = load_ckpt(model, optimizer, args.last_ckpt, device) lr_decay_lambda = lambda epoch: args.lr_decay_rate ** (epoch // args.lr_epoch_per_decay) scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda) writer = SummaryWriter(args.summary_dir) for epoch in range(int(args.start_epoch), args.epochs): optimizer.step() scheduler.step(epoch) local_count = 0 last_count = 0 end_time = time.time() if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch: save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch, local_count, num_train) for batch_idx, sample in enumerate(train_loader): image = sample['image'].to(device) plane = sample['plane'].to(device) target_scales = [sample[s].to(device) for s in ['label', 'label2', 'label3', 'label4', 'label5']] optimizer.zero_grad() pred_scales = model(image, plane, args.checkpoint) loss = CEL_weighted(pred_scales, target_scales) loss.backward() optimizer.step() local_count += image.data.shape[0] global_step += 1 if global_step % args.print_freq == 0 or global_step == 1: time_inter = time.time() - end_time count_inter = local_count - last_count print_log(global_step, epoch, local_count, count_inter, num_train, loss, time_inter) end_time = time.time() for name, param in model.named_parameters(): writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step, bins='doane') grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True) writer.add_image('image', grid_image, global_step) # RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0 grid_image = make_grid(abs(plane[:, 0:3, :, :].clone().cpu().data), 3, normalize=True) writer.add_image('plane', grid_image, global_step) grid_image = make_grid(utils.color_label(torch.max(pred_scales[0][:3], 1)[1] + 1), 3, normalize=False, range=(0, 255)) writer.add_image('Predicted label', grid_image, global_step) grid_image = make_grid(utils.color_label(target_scales[0][:3]), 3, normalize=False, range=(0, 255)) writer.add_image('Groundtruth label', grid_image, global_step) writer.add_scalar('CrossEntropyLoss', loss.data, global_step=global_step) writer.add_scalar('Learning rate', scheduler.get_last_lr()[0], global_step=global_step) # it has to be get_last_lr here after pytorch 1.4(?). get_lr() is no longer current lr. last_count = local_count save_ckpt(args.ckpt_dir, model, optimizer, global_step, args.epochs, 0, num_train) print("Training completed ")
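# Standalone illustration of the schedule used above: the LambdaLR multiplier is
# lr_decay_rate ** (epoch // lr_epoch_per_decay), i.e. the LR is cut by a constant factor
# every lr_epoch_per_decay epochs.  The numbers below are placeholders, not this script's args.
import torch
from torch.optim.lr_scheduler import LambdaLR

lr_decay_rate, lr_epoch_per_decay = 0.8, 100
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=2e-3)
scheduler = LambdaLR(optimizer,
                     lr_lambda=lambda epoch: lr_decay_rate ** (epoch // lr_epoch_per_decay))
for epoch in range(300):
    optimizer.step()                      # optimizer step first, then scheduler (PyTorch >= 1.1)
    scheduler.step()
    if epoch % 100 == 0:
        # get_last_lr() is the LR actually applied this epoch; get_lr() no longer
        # reports the current LR reliably from PyTorch 1.4 onwards.
        print(epoch, scheduler.get_last_lr()[0])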
def run_training( source, target, dataset_root, net_name, da_method, max_iter, stop_iter, test_iter, logdir, run_name, gpu_id, load_workers, config, test_src: bool = False, use_tqdm: bool = True, kill_diverging: bool = False): dev = torch.device(f'cuda:{gpu_id}') if kill_diverging: assert test_src # Get config # Config arrives here (from BOHB or direct cli invocation) as a dictionary like # {'disc.dropout': 0.5, 'net.bottleneck_size_log': 9} # We separate it in something like # {'disc': {'dropout': 0.5'}, 'net': {'bottleneck_size_log': 9}} config = split_dict(config) # Disc args are not meaningful without DA if da_method != 'so': # Default disc args disc_args = { 'dropout': 0.5, 'num_fc_layers': 3, 'hidden_size_log': 10 } # Update with the ones coming from config (if any) disc_args.update(config.get('disc', {})) # Some args might be defined as log2. Replace them (bottleneck_size_log -> bottleneck_size) remove_log_hps(disc_args) # Print disc args print(f"Discriminator config: {disc_args}") # Very similar, but for the backbone net_args = { 'use_bottleneck': da_method != 'so', 'bottleneck_size_log': 9 } net_args.update(config.get('net', {})) remove_log_hps(net_args) print(f"Backbone config: {net_args}") # Now net_args and disc_args are ready to be passed to the network constructors as **kwargs :) bs, lr, wd = config['base']['bs'], config['base']['lr'], config['base']['wd'] # Load datasets and their number o classes dset_src_train, dset_src_test, dset_trg_train, dset_trg_test, num_classes = \ prepare_datasets(source, target, dataset_root) dload_src_train = DataLoader(dset_src_train, batch_size=bs, shuffle=True, num_workers=load_workers, drop_last=True) dload_src_test = DataLoader(dset_src_test, batch_size=bs, shuffle=False, num_workers=load_workers) dload_trg_train = DataLoader(dset_trg_train, batch_size=bs, shuffle=True, num_workers=load_workers, drop_last=True) dload_trg_test = DataLoader(dset_trg_test, batch_size=bs, shuffle=False, num_workers=load_workers) print(f"Source samples: {len(dset_src_train)}") print(f"Target samples: {len(dset_trg_train)}") print(f"Num classes: {num_classes}") # Build network base_network = resnet.ResNetFc( resnet_name=net_name, num_classes=num_classes, plug_position=7, **net_args ).to(dev) params = base_network.get_parameters(lr, wd) # Source only has no secondary branches if da_method != 'so': disc_classes = { # ( -> confusion matrix) 'alda': num_classes, # ( -> binary domain classifier) 'dann': 2 }[da_method] discriminator = resnet.Discriminator(in_feature=base_network.output_size(), num_classes=disc_classes, **disc_args).to(dev) params += discriminator.get_parameters(lr, wd) # Define optimizer optimizer = opt.SGD( params=params, lr=lr, momentum=0.9, weight_decay=wd, nesterov=True ) # Lr policy lr_schedule = LambdaLR(optimizer, lr_lambda=lambda it: (1 + 0.001 * it) ** (-0.75)) # Logger writer = Logger(logdir=logdir, run_name=run_name, use_tb=True, use_tqdm=use_tqdm) # Classification loss ce_loss = nn.CrossEntropyLoss() # Train loop len_train_source = len(dload_src_train) len_train_target = len(dload_trg_train) lambda_val = 0. 
# We store all the metrics here metrics = [] all_pseudolabels = [] with writer.progress(total=stop_iter, desc="Training") as pb: for i in range(stop_iter): if (i + 1) % test_iter == 0: print(f"Iteration: {i + 1} / {stop_iter} (max: {max_iter})") print("Testing...") base_network.train(False) # This dict contains metric-name -> value pairs for the current epoch new_metrics = {} if test_src: test_result, _, src_test_feats = test(dload_src_test, base_network, device=dev) # Print accuracy print("Source accuracy: {:.3f} %".format(test_result['accuracy'] * 100)) # Add the source metrics to the dict (with the source_ prefix) new_metrics.update({f'source_{k}': v for k, v in test_result.items()}) test_result, epoch_pseudolabels, _ = test(dload_trg_test, base_network, device=dev, source_feats=src_test_feats) all_pseudolabels.append(epoch_pseudolabels) print(f"Target accuracy: {test_result['accuracy'] * 100:.3f} %") writer.add_scalar('train/base_lr', lr_schedule.get_last_lr()[0], i) writer.add_scalar('train/lambda', lambda_val, i) new_metrics.update({f'target_{k}': v for k, v in test_result.items()}) # Add all the new metrics to tensorboard logs add_scalars(writer, new_metrics, global_step=i, prefix='test/') # Add a column with iteration number new_metrics.update({'iter': i}) # Concatenate to older epoch metrics metrics.append(new_metrics) # Kill this training if source loss goes too high if kill_diverging and new_metrics['source_class_loss'] > SOURCE_LOSS_THRESHOLD: if len(metrics) > 0 and new_metrics['source_class_loss'] > metrics[-1]['source_class_loss']: print(f"Increasing source_class_loss exceeds maximum allowed source loss ({new_metrics['source_class_loss']} > {SOURCE_LOSS_THRESHOLD})") break # Train one iteration base_network.train(True) if da_method != 'so': discriminator.train(True) optimizer.zero_grad() # Reset data loops if required if i % len_train_source == 0: iter_source = iter(dload_src_train) if i % len_train_target == 0: iter_target = iter(dload_trg_train) # Load source inputs_source, labels_source = iter_source.next() inputs_source, labels_source = map_to_device(dev, (inputs_source, labels_source)) # Compute source features and classification output outputs_source, features_source = base_network(inputs_source) # Classification loss classifier_loss = ce_loss(outputs_source, labels_source) # Actual DA part if da_method != 'so': # Load target samples without target labels inputs_target, _ = iter_target.next() inputs_target = inputs_target.to(dev) # Compute target features and classification output outputs_target, features_target = base_network(inputs_target) # Source and target features features = torch.cat((features_source, features_target), dim=0) # Source and target classification outputs (-> softmax) outputs = torch.cat((outputs_source, outputs_target), dim=0) softmax_out = nn.Softmax(dim=1)(outputs) # CORE if da_method == 'dann': p = float(i / max_iter) lambda_val = 2. / (1 + np.exp(-10 * p)) - 1 ad_out = discriminator(features, lambda_val) adv_loss = loss.DANN_loss(ad_out) transfer_loss = adv_loss if (i + 1) % test_iter == 0: print("Transfer loss: {:.3f}".format(transfer_loss.item())) elif da_method == 'alda': p = float(i / max_iter) lambda_val = 2. 
/ (1 + np.exp(-10 * p)) - 1 ad_out = discriminator(features, lambda_val) adv_loss, reg_loss, correct_loss = loss.ALDA_loss(ad_out, labels_source, softmax_out, threshold=0.9) transfer_loss = adv_loss + lambda_val * correct_loss if (i + 1) % test_iter == 0: print("Transfer loss: {:.3f}, reg loss {:.3f}%".format(transfer_loss.item(), reg_loss.item())) # Backpropagate reg_loss only through the discriminator with base_network.freeze(): reg_loss.backward(retain_graph=True) # END CORE else: transfer_loss = 0 total_loss = classifier_loss + config['base']['weight_da'] * transfer_loss total_loss.backward() optimizer.step() lr_schedule.step() if (i + 1) % test_iter == 0 and da_method != 'so': writer.add_scalar('train/transfer_loss', transfer_loss.item(), i) pb.update(1) # Convert list of dicts to dataframe containing metrics metrics = pd.DataFrame(metrics) # Compute global-pseudolabel accuracy all_pseudolabels = np.array(all_pseudolabels) global_pseudolabels = compute_time_consistent_pseudolabels(all_pseudolabels, num_classes) pseudolabel_acc = np.equal(all_pseudolabels, global_pseudolabels).sum(axis=1) / global_pseudolabels.shape[0] # Add it to the metrics dataframe metrics['target_pseudolabels'] = pseudolabel_acc # Save the metrics with open(os.path.join(logdir, run_name, "metrics.pkl"), "wb") as fp: pickle.dump(metrics, fp) # Log global pseudolabel accuracy to tensorboard for i in range(len(all_pseudolabels)): writer.add_scalar('test/target_pseudolabels', float(pseudolabel_acc[i]), i * test_iter) return metrics
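# The two scalar schedules used above, pulled out as plain functions for a quick numeric check:
# the inverse-decay multiplier passed to LambdaLR and the lambda ramp applied to the
# gradient-reversal branch for DANN/ALDA.
import numpy as np


def lr_factor(it):
    return (1 + 0.001 * it) ** (-0.75)


def grl_lambda(i, max_iter):
    p = float(i / max_iter)
    return 2.0 / (1.0 + np.exp(-10 * p)) - 1.0


for it in (0, 1000, 10000):
    print(it, round(lr_factor(it), 4), round(grl_lambda(it, 10000), 4))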
def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if args.center_crop: train_transform = T.Compose([ ResizeImage(256), T.CenterCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), normalize ]) else: train_transform = T.Compose([ ResizeImage(256), T.RandomResizedCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), normalize ]) val_transform = T.Compose( [ResizeImage(256), T.CenterCrop(224), T.ToTensor(), normalize]) dataset = datasets.__dict__[args.data] train_source_dataset = dataset(root=args.root, task=args.source, download=True, transform=train_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_dataset = dataset(root=args.root, task=args.target, download=True, transform=train_transform) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_dataset = dataset(root=args.root, task=args.target, download=True, transform=val_transform) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) if args.data == 'DomainNet': test_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) else: test_loader = val_loader train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = models.__dict__[args.arch](pretrained=True) num_classes = train_source_dataset.num_classes classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim).to(device) classifier_feature_dim = classifier.features_dim if args.randomized: domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024).to(device) else: domain_discri = DomainDiscriminator(classifier_feature_dim * num_classes, hidden_size=1024).to(device) all_parameters = classifier.get_parameters( ) + domain_discri.get_parameters() # define optimizer and lr scheduler optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = LambdaLR( optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x))**(-args.lr_decay)) # define loss function domain_adv = ConditionalDomainAdversarialLoss( domain_discri, entropy_conditioning=args.entropy, num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized, randomized_dim=args.randomized_dim).to(device) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = validate(test_loader, classifier, args) print(acc1) return # start training best_acc1 = 0. for epoch in range(args.epochs): print("lr:", lr_scheduler.get_last_lr()[0]) # train for one epoch train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1 = validate(val_loader, classifier, args) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = validate(test_loader, classifier, args) print("test_acc1 = {:3.1f}".format(acc1)) logger.close()
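# ForeverDataIterator comes from the transfer-learning library imported by this script and is
# not defined here.  A minimal sketch of the pattern it implements, assuming it simply restarts
# the DataLoader whenever the underlying iterator is exhausted so the source and target streams
# never run dry inside the fixed-length training loop:
from torch.utils.data import DataLoader


class ForeverIteratorSketch:
    """Yield batches indefinitely by re-creating the iterator on StopIteration."""

    def __init__(self, loader: DataLoader):
        self.loader = loader
        self.it = iter(loader)

    def __next__(self):
        try:
            return next(self.it)
        except StopIteration:
            self.it = iter(self.loader)
            return next(self.it)

    def __len__(self):
        return len(self.loader)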
def train(rank, args, model, model_t, train_dataset_qc, test_dataset_qc, fq_tune_only, model_controller): """ Train the model """ global train_count train_count += 1 world_size = 1 if rank < 0 else torch.distributed.get_world_size() if rank in [-1, 0]: printlog("Train model", train_count) printlog(model) q_dataset = train_dataset_qc.q_dataset per_gpu_train_batch_size = args.per_gpu_train_batch_size train_batch_size = per_gpu_train_batch_size * world_size if fq_tune_only: gradient_accumulation_steps = 1 num_train_epochs = 1 else: gradient_accumulation_steps = args.total_train_batch_size // train_batch_size num_train_epochs = args.num_train_epochs if rank < 0: #single process take all q_sampler = RandomSampler(q_dataset) q_dataloader = DataLoader(q_dataset, sampler=q_sampler, batch_size=train_batch_size, num_workers=4) else: #special sampler that divide samples between processes q_sampler = torch.utils.data.distributed.DistributedSampler(q_dataset, rank=rank) q_dataloader = DataLoader(q_dataset, sampler=q_sampler, batch_size=per_gpu_train_batch_size) steps_total = int( len(q_dataloader) // gradient_accumulation_steps * num_train_epochs) # Prepare optimizer and schedule named_params, groups = utils.make_param_groups( rank, model, args. freeze_list, #list or str with subnames to define frozen parameters args.learning_rate, #learning rate for no FQ parameters 0.01, # learning rate for FQ parameters fq_tune_only, #true if only FQ parameters will be optimized model_controller) optimizer = AdamW(groups, eps=1e-08, lr=args.learning_rate, weight_decay=0) def lr_lambda(current_step): p = float(current_step) / float(steps_total) return 1 - p scheduler = LambdaLR(optimizer, lr_lambda) if rank in [-1, 0]: for n, p in named_params: printlog('param for tune', n) printlog("fq_tune_only", fq_tune_only) printlog("dataset size", len(q_dataset)) printlog("epoches", num_train_epochs) printlog("per_gpu_train_batch_size", per_gpu_train_batch_size) printlog("n_gpu", args.n_gpu) printlog("world_size", world_size) printlog("gradient_accumulation_steps", gradient_accumulation_steps) printlog("total train batch size", train_batch_size * gradient_accumulation_steps) printlog("steps_total", steps_total) global_step = 1 model.zero_grad() indicators = collections.defaultdict(list) softplus = torch.nn.Softplus() loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')]) hnm_hist = {} for epoch in range(math.ceil(num_train_epochs)): indicators = collections.defaultdict(list) model.train() if model_t: model_t.train() if rank > -1: #set epoch to make different samples division betwen process for different epoches q_sampler.set_epoch(epoch) utils.sync_models(rank, model) for step, q_batch in enumerate(q_dataloader): epoch_fp = epoch + step / len(q_dataloader) if epoch_fp > num_train_epochs: break losses = [] context_ids_pos = q_batch[3] q_inputs = get_inputs(q_batch, args.device) q_outputs = model(**q_inputs, output_hidden_states=(model_t is not None)) q_vec = q_outputs[0] #get positive embeddings c_batch = train_dataset_qc.c_dataset[context_ids_pos.detach().data] c_inputs = get_inputs(c_batch, args.device) c_outputs = model(**c_inputs, output_hidden_states=(model_t is not None)) c_vec_pos = c_outputs[0] if model_t is not None: q_emb_s, q_hidden_s = q_outputs c_emb_s, c_hidden_s = c_outputs with torch.no_grad(): q_emb_t, q_hidden_t = model_t(**q_inputs, output_hidden_states=True) c_emb_t, c_hidden_t = model_t(**c_inputs, output_hidden_states=True) def align_and_loss_outputs(out_s, out_t): if len(out_s) != len(out_t): 
#the student and teacher outputs are not aligned. try to find teacher output for each student output n_s, n_t = len(out_s), len(out_t) out_t = [ out_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s) ] assert len(out_s) == len( out_t ), "can not align number of outputs between student and teacher" assert all( s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape) ), "output shapes for student and teacher are not the same" return [(s - t.detach()).pow(2).mean() for s, t in zip(out_s, out_t)] l_q = align_and_loss_outputs(q_hidden_s, q_hidden_t) l_c = align_and_loss_outputs(c_hidden_s, c_hidden_t) emb_loss = loss_cfg.get('emb_loss', '') if emb_loss == 'L2': l_q.append((q_emb_s - q_emb_t.detach()).pow(2).mean()) l_c.append((c_emb_s - c_emb_t.detach()).pow(2).mean()) elif emb_loss == 'L1': l_q.append((q_emb_s - q_emb_t.detach()).abs().mean()) l_c.append((c_emb_s - c_emb_t.detach()).abs().mean()) elif emb_loss.lower() not in ['', 'none', '0', 'disable']: raise Exception( 'emb_loss={} is unsupported'.format(emb_loss)) losses.extend([args.supervision_weight * l for l in l_c + l_q]) triplet_num = int(loss_cfg.get('triplet_num', 1)) if fq_tune_only: triplet_num = 0 if triplet_num > 0: #disable grad to select negatives with torch.no_grad(): hnm_scores = [] hnm_idxs = [] #check that current step has no HNM conext vector if global_step not in hnm_hist and args.hnm_num > 0: #generate the new one if world_size > 1 and (args.hnm_num % world_size) != 0: #aligh hnm_num per each replica hnm_plus = world_size - (args.hnm_num % world_size) args.hnm_num += hnm_plus logger.warning( "rank {} args.hnm_num increased by {} from {} to {} to be the same after division by {} replicas." .format(rank, hnm_plus, args.hnm_num - hnm_plus, args.hnm_num, world_size)) # generate random contexts to calc embedding context_ids_all = torch.randint( low=0, high=len(train_dataset_qc.c_dataset), size=[args.hnm_num]) if rank < 0: #single process take all context_ids = context_ids_all else: #broadcast one sigle indicies to all processes context_ids_all = context_ids_all.to(args.device) torch.distributed.broadcast(context_ids_all, 0) context_ids_all = context_ids_all.cpu() #each process take only small part to calc embedding s = ((rank + 0) * args.hnm_num) // world_size e = ((rank + 1) * args.hnm_num) // world_size context_ids = context_ids_all[s:e] batch_size = min(args.hnm_batch_size, context_ids.shape[0]) s, e = 0, batch_size c_outputs = [] while e > s: idx = context_ids.detach()[s:e] c_batch = train_dataset_qc.c_dataset[idx] inputs = get_inputs(c_batch, args.device) outputs = model(**inputs, output_hidden_states=False) c_outputs.append(outputs[0]) s, e = e, min(e + batch_size, context_ids.shape[0]) context_emb = torch.cat(c_outputs, dim=0) if rank < 0: # single process calculated all context_emb_all = context_emb else: context_emb_list = [ torch.zeros_like(context_emb) for _ in range(world_size) ] torch.distributed.all_gather( context_emb_list, context_emb) context_emb_all = torch.cat(context_emb_list, dim=0) hnm_hist[global_step] = (context_ids_all, context_emb_all) #check history size and crop the oldest one if len(hnm_hist) > args.hnm_hist_num: del hnm_hist[min(hnm_hist.keys())] #calc HNM scores for current question batch for hist_step, (c_idx, c_vec) in hnm_hist.items(): w = args.hnm_hist_alpha**(global_step - hist_step) t1 = q_vec[:, None, :] t2 = c_vec[None, :, :] d = (t1 - t2) score = -d.norm(2, dim=-1) score = score * w hnm_scores.append(score) hnm_idxs.append(c_idx) if hnm_scores: #choose the hardest negative if we have 
scores score = torch.cat(hnm_scores, dim=-1) idx = torch.cat(hnm_idxs, dim=-1) score = score.cpu() pos_mask = (context_ids_pos[:, None] == idx[None, :]).to( dtype=score.dtype, device=score.device) score = (1 - pos_mask) * score + pos_mask * score.min( ) #make positive context with small score to avoid chose it as hard neg hn_idx = score.argmax(dim=1, keepdim=True) context_ids_neg = idx[hn_idx] else: #just random selection in case of no scores for HNM size = (context_ids_pos.shape[0], 1) context_ids_neg = torch.randint( 0, len(train_dataset_qc.c_dataset) - 1, size) shift = (context_ids_neg >= context_ids_pos[:, None]) context_ids_neg = context_ids_neg + shift.to( dtype=context_ids_neg.dtype) d_pos = (q_vec - c_vec_pos).norm(2, dim=-1) # get negative embeddings and calc losses for neg_index in range(context_ids_neg.shape[1]): ids = context_ids_neg[:, neg_index] c_batch = train_dataset_qc.c_dataset[ids.detach()] inputs = get_inputs(c_batch, args.device) outputs = model(**inputs, output_hidden_states=False) c_vec_neg = outputs[0] for triplet_index in range(triplet_num): if triplet_index == 0: d_neg = (q_vec - c_vec_neg).norm(2, dim=-1) if triplet_index == 1: d_neg = (c_vec_pos - c_vec_neg).norm(2, dim=-1) d_diff = d_pos - d_neg indicators['dd' + str(triplet_index)].append( [v.mean().item() for v in (d_pos, d_neg, d_diff)]) l = softplus(d_diff) losses.append(l) del d_neg del d_pos #average over batch losses = [l.mean() for l in losses] l = sum(losses) / len(losses) (l / gradient_accumulation_steps).backward() indicators['loss'].append(l.item()) indicators['ll'].append([lll.item() for lll in losses]) #del losses del l if (step + 1) % gradient_accumulation_steps == 0: utils.sync_grads(rank, named_params, report_no_grad_params=(global_step == 1)) torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if global_step % 10 == 0: # Log metrics wall_time = epoch + step / len(q_dataloader) lrp = [ '{:.2f}'.format(i) for i in np.log10(scheduler.get_last_lr()) ] str_out = "{} ep {:.2f} lrp {}".format( train_count, epoch_fp, " ".join(lrp)) for k, v in indicators.items(): v = np.array(v) if len(v.shape) == 1: v = v[:, None] if rank > -1: #sync indicators vt = torch.tensor(v).to(args.device) torch.distributed.all_reduce( vt, op=torch.distributed.ReduceOp.SUM) v = vt.cpu().numpy() / float(world_size) str_out += " {} {}".format( k, " ".join(["{:.3f}".format(t) for t in v.mean(0)])) if 'score' in locals(): str_out += " SS {}".format(list(score.shape)) if 'time_last' in locals(): dt_iter = (time.time() - time_last) / len( indicators['loss']) dt_ep = dt_iter * len(q_dataloader) str_out += " it {:.1f}s".format(dt_iter) str_out += " ep {:.1f}m".format(dt_ep / (60)) str_out += " eta {:.1f}h".format( dt_ep * (num_train_epochs - epoch_fp) / (60 * 60)) time_last = time.time() indicators = collections.defaultdict(list) if rank in [-1, 0]: logger.info(str_out) if rank in [-1, 0]: check_point_name = 'checkpoint-{:02}'.format(train_count) check_point_name = check_point_name + '-{:02}'.format(epoch + 1) result_s = evaluate(args, model.eval(), test_dataset_qc) for k, v in result_s.items(): logger.info("{} {} {}".format(check_point_name, k, v)) if rank > -1: torch.distributed.barrier()
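# The schedule above decays the LR linearly to zero over steps_total optimiser steps, and the
# scheduler is advanced only once per gradient_accumulation_steps mini-batches.  Minimal
# standalone sketch of that interaction with toy sizes and a dummy loss:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

steps_total, grad_accum = 10, 4
param = torch.nn.Parameter(torch.zeros(3))
optimizer = AdamW([param], lr=3e-5)
scheduler = LambdaLR(optimizer, lambda step: 1.0 - float(step) / float(steps_total))

for step in range(steps_total * grad_accum):
    loss = (param ** 2).sum()              # dummy loss standing in for the triplet/distill losses
    (loss / grad_accum).backward()         # scale so the accumulated gradient is an average
    if (step + 1) % grad_accum == 0:
        torch.nn.utils.clip_grad_norm_([param], 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        print(scheduler.get_last_lr()[0])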
def lr_range_test(self, data_loader, end_lr, num_iter=100, step_mode='exp', alpha=0.05, ax=None):
    # Since the test updates both model and optimizer we need to store
    # their initial states to restore them in the end
    previous_states = {
        'model': deepcopy(self.model.state_dict()),
        'optimizer': deepcopy(self.optimizer.state_dict())
    }
    # Retrieves the learning rate set in the optimizer
    start_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
    # Builds a custom function and corresponding scheduler
    lr_fn = make_lr_fn(start_lr, end_lr, num_iter)
    scheduler = LambdaLR(self.optimizer, lr_lambda=lr_fn)
    # Variables for tracking results and iterations
    tracking = {'loss': [], 'lr': []}
    iteration = 0
    # If there are more iterations than mini-batches in the data loader,
    # it will have to loop over it more than once
    while (iteration < num_iter):
        # That's the typical mini-batch inner loop
        for x_batch, y_batch in data_loader:
            x_batch = x_batch.to(self.device)
            y_batch = y_batch.to(self.device)
            # Step 1
            yhat = self.model(x_batch)
            # Step 2
            loss = self.loss_fn(yhat, y_batch)
            # Step 3
            loss.backward()
            # Here we keep track of the losses (smoothed)
            # and the learning rates
            tracking['lr'].append(scheduler.get_last_lr()[0])
            if iteration == 0:
                tracking['loss'].append(loss.item())
            else:
                prev_loss = tracking['loss'][-1]
                smoothed_loss = alpha * loss.item() + (1 - alpha) * prev_loss
                tracking['loss'].append(smoothed_loss)
            iteration += 1
            # Number of iterations reached
            if iteration == num_iter:
                break
            # Step 4
            self.optimizer.step()
            scheduler.step()
            self.optimizer.zero_grad()
    # Restores the original states
    self.optimizer.load_state_dict(previous_states['optimizer'])
    self.model.load_state_dict(previous_states['model'])
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    else:
        fig = ax.get_figure()
    ax.plot(tracking['lr'], tracking['loss'])
    if step_mode == 'exp':
        ax.set_xscale('log')
    ax.set_xlabel('Learning Rate')
    ax.set_ylabel('Loss')
    fig.tight_layout()
    return tracking, fig
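# make_lr_fn() is referenced above but not defined in this excerpt.  A sketch of a compatible
# factory, assuming it returns a LambdaLR multiplier that sweeps the LR from start_lr to end_lr
# over num_iter steps, exponentially by default ('exp') or linearly:
import numpy as np


def make_lr_fn_sketch(start_lr, end_lr, num_iter, step_mode='exp'):
    if step_mode == 'linear':
        factor = (end_lr / start_lr - 1) / num_iter

        def lr_fn(iteration):
            return 1 + iteration * factor
    else:
        factor = (np.log(end_lr) - np.log(start_lr)) / num_iter

        def lr_fn(iteration):
            return np.exp(factor * iteration)
    return lr_fn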
class TrainClass: def __init__( self, model, train_configs_inp, save_folder, final_save_name, snapshot_path, logger, ): self.model = model self.train_configs = train_configs_inp self.save_folder = save_folder self.final_save_name = final_save_name self.logger = logger self.device = torch.device("cuda") train_utils.print_model(self.model, self.logger) logger.info("Train params:\t%s\n", self.train_configs) self.logger.info( "TRAINING PARAMETERS:\t" "optimizer: adamax\t" "base_learning_rate = %.8f,\t" "grad_clip=%.2f\n", self.train_configs.base_learning_rate, self.train_configs.grad_clip, ) self.optimizer = torch.optim.Adamax( filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.train_configs.base_learning_rate, ) if snapshot_path: self._load_model(snapshot_path) lr_for_epochs = train_utils.get_lr_for_epochs( self.train_configs)[self.train_configs.start_epoch:] self.logger.info("LR for epochs : %s", lr_for_epochs) self.scheduler = LambdaLR( self.optimizer, lr_lambda=lambda epoch: (lr_for_epochs[epoch] / self.train_configs.base_learning_rate), ) def train(self, train_loader, eval_loader): for epoch in range(self.train_configs.start_epoch, self.train_configs.number_of_epochs): self.logger.info( "Training For Epoch: %d\tLearning rate = %.4f", epoch, self.scheduler.get_last_lr()[0], ) epoch_start_time = time.time() train_size = len(train_loader.dataset) total_loss, total_score = self._train_epoch(train_loader) # Update learning rate. Skip updating in the last iteration. if epoch != self.train_configs.number_of_epochs - 1: self.scheduler.step() total_loss /= train_size total_score = 100 * total_score / train_size eval_score = 0 self.logger.info( "epoch %d,\t" "train_size: %d,\t" "time: %.2f,\t" "train_loss: %.2f\t" "SCORE: %.4f\n\n", epoch, train_size, time.time() - epoch_start_time, total_loss, total_score, ) if epoch == self.train_configs.number_of_epochs - 1: self.logger.info("Saving model as %s", "final.pth") model_path = os.path.join(self.save_folder, "final") train_utils.save_model(model_path, self.model, self.optimizer, epoch, total_score) self._save_model_if_eligible(epoch, total_score) if (eval_loader and total_score > self.train_configs.save_score_threshold): self.model.train(False) self.logger.info("Threshold reached. Evaluating..") eval_score, _ = evaluate(self.model, eval_loader) self.model.train(True) self.logger.info("EVAL SCORE : %.4f\n\n", eval_score * 100) def _train_epoch(self, train_loader): total_loss = 0 total_score = 0 total_attention_loss = 0 for _, (image_features, _, question, labels) in enumerate( tqdm( train_loader, total=len(train_loader), position=0, leave=True, colour="blue", )): image_features = Variable(image_features).to(self.device) question = Variable(question).to(self.device) labels = Variable(labels).to(self.device) pred, v_att, _ = self.model(image_features, question) loss = loss_utils.classification_loss(pred, labels) # Clearing old gradients. self.optimizer.zero_grad() # Computes the gradient for the parameters. loss.backward() # Clips the norm of the overall gradient. Prevents exploding gradients. nn.utils.clip_grad_norm_(self.model.parameters(), self.train_configs.grad_clip) # Updates all the parameters based on the gradients. 
self.optimizer.step() total_loss += loss.data.item() * image_features.size(0) total_score += loss_utils.compute_score(pred, labels.data).sum() return total_loss, total_score def _load_model(self, snapshot_path): model_data = torch.load(snapshot_path) self.model.load_state_dict(model_data.get("model_state", model_data)) self.optimizer.load_state_dict( model_data.get("optimizer_state", model_data)) self.train_configs.start_epoch = model_data["epoch"] + 1 def _save_model_if_eligible(self, epoch, total_score): if total_score >= 75 and epoch % self.train_configs.save_step == 0: save_name = "model_epoch{0}_score_{1}.pth".format( epoch, int(total_score)) self.logger.info("Saving model as %s", save_name) model_path = os.path.join(self.save_folder, save_name) train_utils.save_model(model_path, self.model, self.optimizer, epoch, total_score)
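# The TrainClass constructor above drives LambdaLR from an explicit per-epoch LR table
# (train_utils.get_lr_for_epochs is not shown in this excerpt).  Minimal sketch of the same
# trick with a hand-written table: the multiplier is the desired LR for the epoch divided by
# the optimiser's base LR, and the scheduler is not stepped after the final epoch so the
# table index never overflows.
import torch
from torch.optim.lr_scheduler import LambdaLR

base_lr = 2e-3
lr_for_epochs = [2e-3, 2e-3, 1e-3, 1e-3, 5e-4]       # placeholder schedule, not the real config
optimizer = torch.optim.Adamax([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: lr_for_epochs[epoch] / base_lr)

for epoch in range(len(lr_for_epochs)):
    # ... one epoch of training would go here ...
    print(epoch, scheduler.get_last_lr()[0])
    if epoch != len(lr_for_epochs) - 1:
        scheduler.step()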