def evaluate_epoch(engine):
    # Log training metrics for the completed epoch.
    log_results(logger, 'train/epoch', engine.state, engine.state.epoch)
    # Run a validation pass and log the results to the logger and the console.
    state = evaluate_once(evaluator, iterator=iters['val'])
    log_results(logger, 'valid/epoch', state, engine.state.epoch)
    log_results_cmd('valid/epoch', state, engine.state.epoch)
    # Checkpoint the model and report validation performance per logic level.
    save_ckpt(args, engine.state.epoch, engine.state.metrics['loss'], model, vocab)
    evaluate_by_logic_level(args, model, iterator=iters['val'])
def save_epoch(engine):
    # Log pretraining metrics for the completed epoch to the logger and the console.
    log_results(logger, 'pretrain/epoch', engine.state, engine.state.epoch)
    log_results_cmd('pretrain/epoch', engine.state, engine.state.epoch)
    # Checkpoint the pretrained model.
    save_ckpt(args, engine.state.epoch, engine.state.metrics['loss'], model)
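# How these handlers get attached is not shown in this excerpt. Their `engine`
# argument and use of `engine.state` match the pytorch-ignite handler
# interface, so a minimal registration sketch could look like the following
# (the `trainer` Engine name is an assumption, not taken from the original code):
from ignite.engine import Events

trainer.add_event_handler(Events.EPOCH_COMPLETED, evaluate_epoch)  # training run
trainer.add_event_handler(Events.EPOCH_COMPLETED, save_epoch)      # pretraining run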
# Imports used by run(). MnistBags, Attention, load_ckpt, save_ckpt, train and
# test are assumed to come from this project's own modules; SummaryWriter is
# the TensorBoard writer (torch.utils.tensorboard assumed here).
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def run(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_worker = 2
    n_epoch = args.epochs

    # Seed CPU and GPU RNGs for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    print('Load Train and Test Set')
    train_loader = DataLoader(
        MnistBags(target_number=args.target_number,
                  min_target_count=args.min_target_count,
                  mean_bag_length=args.mean_bag_length,
                  var_bag_length=args.var_bag_length,
                  scale=args.scale,
                  num_bag=args.num_bags_train,
                  seed=args.seed,
                  train=True),
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=n_worker,
        pin_memory=torch.cuda.is_available())
    test_loader = DataLoader(
        MnistBags(target_number=args.target_number,
                  min_target_count=args.min_target_count,
                  mean_bag_length=args.mean_bag_length,
                  var_bag_length=args.var_bag_length,
                  scale=args.scale,
                  num_bag=args.num_bags_test,
                  seed=args.seed,
                  train=False),
        batch_size=args.batchsize,
        shuffle=False,
        num_workers=n_worker,
        pin_memory=torch.cuda.is_available())

    # Resume from the last checkpoint if one exists, otherwise start fresh.
    checkpoint = load_ckpt()
    if checkpoint:
        print('Resume training ...')
        start_epoch = checkpoint.epoch
        model = checkpoint.model
    else:
        print('Brand new training ...')
        start_epoch = 0
        model = Attention()

    # Use multiple GPUs if available.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           betas=(0.9, 0.999), weight_decay=args.reg)
    if checkpoint:
        try:
            optimizer.load_state_dict(checkpoint.optimizer)
        except Exception:
            print('[WARNING] optimizer not restored from last checkpoint, '
                  'continuing without previous state')
    # Free the checkpoint reference.
    del checkpoint

    log_dir = os.path.join('logs', args.logname)
    n_cv_epoch = 1  # evaluate on the test set every n_cv_epoch epochs

    with SummaryWriter(log_dir) as writer:
        print('\nTraining started ...')
        # Epoch numbering is 1-based (continues from start_epoch when resuming).
        for epoch in range(start_epoch + 1, n_epoch + start_epoch + 1):
            train(model, optimizer, train_loader, epoch, writer)
            if epoch % n_cv_epoch == 0:
                with torch.no_grad():
                    test(model, optimizer, test_loader, epoch, writer)
            save_ckpt(model, optimizer, epoch)
        print('\nTraining finished ...')
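# A minimal command-line entry point for run(), sketched from the attribute
# names that run() reads off `args`. All default values below are assumptions
# for illustration, not values taken from the original code.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Attention-based MIL training on MNIST bags')
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--target_number', type=int, default=9)
    parser.add_argument('--min_target_count', type=int, default=1)
    parser.add_argument('--mean_bag_length', type=int, default=10)
    parser.add_argument('--var_bag_length', type=int, default=2)
    parser.add_argument('--scale', type=float, default=1.0)
    parser.add_argument('--num_bags_train', type=int, default=200)
    parser.add_argument('--num_bags_test', type=int, default=50)
    parser.add_argument('--batchsize', type=int, default=1)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--reg', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--logname', type=str, default='attention_mil')
    run(parser.parse_args())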