def make_optimizer(model, cfg):
    assert cfg.SOLVER.OPTIMIZER in [
        'Adam', 'SGD', 'Ranger', 'RangerQH', 'RangerALR'
    ], 'Optimizer name not recognized!'
    if cfg.SOLVER.OPTIMIZER == 'Adam':
        return torch.optim.Adam(model.parameters(),
                                lr=cfg.SOLVER.LR,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY,
                                betas=cfg.SOLVER.BETAS,
                                amsgrad=cfg.SOLVER.AMSGRAD)
    elif cfg.SOLVER.OPTIMIZER == 'SGD':
        return torch.optim.SGD(model.parameters(),
                               lr=cfg.SOLVER.LR,
                               weight_decay=cfg.SOLVER.WEIGHT_DECAY,
                               nesterov=cfg.SOLVER.NESTEROV)
    elif cfg.SOLVER.OPTIMIZER == 'Ranger':
        return Ranger(model.parameters(),
                      lr=cfg.SOLVER.LR,
                      weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    elif cfg.SOLVER.OPTIMIZER == 'RangerQH':
        return RangerQH(model.parameters(),
                        lr=cfg.SOLVER.LR,
                        weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    elif cfg.SOLVER.OPTIMIZER == 'RangerALR':
        return RangerVA(model.parameters(),
                        lr=cfg.SOLVER.LR,
                        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
                        amsgrad=cfg.SOLVER.AMSGRAD)
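# Hedged usage sketch (not part of the original code): make_optimizer reads all
# hyperparameters from a cfg.SOLVER node. The schema below is an assumption for
# illustration only, built with a yacs CfgNode; the final call stays commented
# out because `model` must be supplied by the caller.
from yacs.config import CfgNode as CN

_example_cfg = CN()
_example_cfg.SOLVER = CN()
_example_cfg.SOLVER.OPTIMIZER = 'Adam'        # one of the names asserted above
_example_cfg.SOLVER.LR = 3e-4
_example_cfg.SOLVER.WEIGHT_DECAY = 1e-4
_example_cfg.SOLVER.BETAS = (0.9, 0.999)
_example_cfg.SOLVER.AMSGRAD = False
_example_cfg.SOLVER.NESTEROV = False
# optimizer = make_optimizer(model, _example_cfg)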
def train(params, n_epochs, verbose=True):
    # init interpolation model
    timestamp = int(time.time())
    formatted_params = '_'.join(f'{k}={v}' for k, v in params.items())
    torch.manual_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)
    if FLAGS.filename is None:
        G = SepConvNetExtended(kl_size=params['kl_size'],
                               kq_size=params['kq_size'],
                               kl_d_size=params['kl_d_size'],
                               kl_d_scale=params['kl_d_scale'],
                               kq_d_scale=params['kq_d_scale'],
                               kq_d_size=params['kq_d_size'],
                               input_frames=params['input_size'])
        if params['pretrain'] in [1, 2]:
            print('LOADING L1')
            G.load_weights('l1')
        name = f'{timestamp}_seed_{FLAGS.seed}_{formatted_params}'
        G = torch.nn.DataParallel(G).cuda()
        # optimizer = torch.optim.Adamax(G.parameters(), lr=params['lr'], betas=(.9, .999))
        # two parameter groups: the separable-convolution layers ('moduleConv')
        # get their own learning rate params['lr2']
        if params['optimizer'] == 'ranger':
            optimizer = Ranger(
                [{
                    'params': [p for l, p in G.named_parameters() if 'moduleConv' not in l]
                }, {
                    'params': [p for l, p in G.named_parameters() if 'moduleConv' in l],
                    'lr': params['lr2']
                }],
                lr=params['lr'],
                betas=(.95, .999))
        elif params['optimizer'] == 'adamax':
            optimizer = torch.optim.Adamax(
                [{
                    'params': [p for l, p in G.named_parameters() if 'moduleConv' not in l]
                }, {
                    'params': [p for l, p in G.named_parameters() if 'moduleConv' in l],
                    'lr': params['lr2']
                }],
                lr=params['lr'],
                betas=(.9, .999))
        else:
            raise NotImplementedError()
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=60 - FLAGS.warmup, T_mult=1, eta_min=1e-5)
        start_epoch = 0
    else:
        # resume from checkpoint
        checkpoint = torch.load(FLAGS.filename)
        G = checkpoint['last_model'].cuda()
        start_epoch = checkpoint['epoch'] + 1
        name = checkpoint['name']
        optimizer = torch.optim.Adamax(
            [{
                'params': [p for l, p in G.named_parameters() if 'moduleConv' not in l]
            }, {
                'params': [p for l, p in G.named_parameters() if 'moduleConv' in l],
                'lr': params['lr2']
            }],
            lr=params['lr'],
            betas=(.9, .999))
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=n_epochs - FLAGS.warmup,
            eta_min=1e-5,
            last_epoch=-1)
        # fast-forward the scheduler to the resumed epoch
        for _ in range(start_epoch - FLAGS.warmup + 1):
            scheduler.step()

    print('SETTINGS:')
    print(params)
    print('NAME:')
    print(name)
    sys.stdout.flush()

    # loss_network = losses.LossNetwork(layers=[9,16,26]).cuda()  # 9, 16, 26
    # Perc_loss = losses.PerceptualLoss(loss_network, include_input=True)
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    quadratic = params['input_size'] == 4
    L1_loss = torch.nn.L1Loss().cuda()
    # Flow_loss = losses.FlowLoss(quadratic=quadratic).cuda()

    ds_train_lmd = dataloader.large_motion_dataset(quadratic=quadratic,
                                                   cropped=True,
                                                   fold='train',
                                                   min_flow=6)
    ds_valid_lmd = dataloader.large_motion_dataset(quadratic=quadratic,
                                                   cropped=True,
                                                   fold='valid')
    ds_vimeo_train = dataloader.vimeo90k_dataset(fold='train', quadratic=quadratic)
    ds_vimeo_test = dataloader.vimeo90k_dataset(fold='test', quadratic=quadratic)
    # train, test_lmd = dataloader.split_data(ds_lmd, [.9, .1])
    train_vimeo, valid_vimeo = dataloader.split_data(ds_vimeo_train, [.9, .1])
    # torch.manual_seed(FLAGS.seed)
    # np.random.seed(FLAGS.seed)
    # random.seed(FLAGS.seed)

    train_settings = {
        'flip_probs': FLAGS.flip_probs,
        'normalize': True,
        'crop_size': (FLAGS.crop_size, FLAGS.crop_size),
        'jitter_prob': FLAGS.jitter_prob,
        'random_rescale_prob': FLAGS.random_rescale_prob
        # 'rescale_distr': (.8, 1.2),
    }
    valid_settings = {
        'flip_probs': 0,
        'random_rescale_prob': 0,
        'random_crop': False,
        'normalize': True
    }
    train_lmd = dataloader.TransformedDataset(ds_train_lmd, **train_settings)
    valid_lmd = dataloader.TransformedDataset(ds_valid_lmd, **valid_settings)
    train_vimeo = dataloader.TransformedDataset(train_vimeo, **train_settings)
    valid_vimeo = dataloader.TransformedDataset(valid_vimeo, **valid_settings)
    test_vimeo = dataloader.TransformedDataset(ds_vimeo_test, **valid_settings)
    train_data = torch.utils.data.ConcatDataset([train_lmd, train_vimeo])

    # displacement: test subset with the 10% largest mean flow
    df = pd.read_csv('hardinstancesinfo/vimeo90k_test_flow.csv')
    test_disp = torch.utils.data.Subset(
        ds_vimeo_test,
        indices=df[df.mean_manh_flow >= df.quantile(.9).mean_manh_flow].index.tolist())
    test_disp = dataloader.TransformedDataset(test_disp, **valid_settings)
    test_disp = torch.utils.data.DataLoader(test_disp, batch_size=4, pin_memory=True)

    # nonlinearity: test subset with the 10% most non-linear motion
    df = pd.read_csv('hardinstancesinfo/Vimeo90K_test.csv')
    test_nonlin = torch.utils.data.Subset(
        ds_vimeo_test,
        indices=df[df.non_linearity >= df.quantile(.9).non_linearity].index.tolist())
    test_nonlin = dataloader.TransformedDataset(test_nonlin, **valid_settings)
    test_nonlin = torch.utils.data.DataLoader(test_nonlin, batch_size=4, pin_memory=True)

    # create weights for train sampler
    df_vim = pd.read_csv('hardinstancesinfo/vimeo90k_train_flow.csv')
    weights_vim = df_vim[df_vim.index.isin(
        train_vimeo.dataset.indices)].mean_manh_flow.tolist()
    weights_lmd = ds_train_lmd.weights
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
        weights_lmd + weights_vim, FLAGS.num_train_samples, replacement=False)

    train_dl = torch.utils.data.DataLoader(train_data,
                                           batch_size=FLAGS.batch_size,
                                           pin_memory=True,
                                           shuffle=False,
                                           sampler=train_sampler,
                                           num_workers=FLAGS.num_workers)
    valid_dl_vim = torch.utils.data.DataLoader(valid_vimeo,
                                               batch_size=4,
                                               pin_memory=True,
                                               num_workers=FLAGS.num_workers)
    valid_dl_lmd = torch.utils.data.DataLoader(valid_lmd,
                                               batch_size=4,
                                               pin_memory=True,
                                               num_workers=FLAGS.num_workers)
    test_dl_vim = torch.utils.data.DataLoader(test_vimeo,
                                              batch_size=4,
                                              pin_memory=True,
                                              num_workers=FLAGS.num_workers)

    # metrics
    writer = SummaryWriter(f'runs/final_exp/full_run_losses/{name}')
    results = ResultStore(writer=writer,
                          metrics=['psnr', 'ssim', 'ie', 'L1_loss', 'lf'],
                          folds=FOLDS)
    early_stopping_metric = 'L1_loss'
    early_stopping = EarlyStopping(results,
                                   patience=FLAGS.patience,
                                   metric=early_stopping_metric,
                                   fold='valid_vimeo')
    loss_network = losses.LossNetwork(layers=[26]).cuda()  # 9, 16, 26
    Perc_loss = losses.PerceptualLoss(loss_network).cuda()

    def do_epoch(dataloader, fold, epoch, train=False):
        assert fold in FOLDS
        if verbose:
            pb = tqdm(desc=f'{fold} {epoch+1}/{n_epochs}',
                      total=len(dataloader),
                      leave=True,
                      position=0)
        for i, (X, y) in enumerate(dataloader):
            X = X.cuda()
            y = y.cuda()
            y_hat = G(X)
            l1_loss = L1_loss(y_hat, y)
            feature_loss = Perc_loss(y_hat, y)
            lf_loss = l1_loss + feature_loss
            if train:
                optimizer.zero_grad()
                lf_loss.backward()
                optimizer.step()
            # compute metrics
            y_hat = (y_hat * 255).clamp(0, 255)
            y = (y * 255).clamp(0, 255)
            psnr = metrics.psnr(y_hat, y)
            ssim = metrics.ssim(y_hat, y)
            ie = metrics.interpolation_error(y_hat, y)
            results.store(
                fold, epoch, {
                    'L1_loss': l1_loss.item(),
                    'psnr': psnr,
                    'ssim': ssim,
                    'ie': ie,
                    'lf': lf_loss.item()
                })
            if verbose:
                pb.update()
        # update tensorboard
        results.write_tensorboard(fold, epoch)
        sys.stdout.flush()

    start_time = time.time()
    for epoch in range(start_epoch, n_epochs):
        G.train()
        do_epoch(train_dl, 'train_fold', epoch, train=True)
        if epoch >= FLAGS.warmup - 1:
            scheduler.step()
        G.eval()
        with torch.no_grad():
            do_epoch(valid_dl_vim, 'valid_vimeo', epoch)
            do_epoch(valid_dl_lmd, 'valid_lmd', epoch)
        if (early_stopping.stop() and epoch >= FLAGS.min_epochs
            ) or epoch % FLAGS.test_every == 0 or epoch + 1 == n_epochs:
            with torch.no_grad():
                do_epoch(test_disp, 'test_disp', epoch)
                do_epoch(test_nonlin, 'test_nonlin', epoch)
                do_epoch(test_dl_vim, 'test_vimeo', epoch)
                visual_evaluation(model=G,
                                  quadratic=params['input_size'] == 4,
                                  writer=writer,
                                  epoch=epoch)
                visual_evaluation_vimeo(model=G,
                                        quadratic=params['input_size'] == 4,
                                        writer=writer,
                                        epoch=epoch)
        filepath_out = os.path.join(MODEL_FOLDER, '{0}_{1}')
        # save model if new best
        if early_stopping.new_best():
            torch.save(G, filepath_out.format('generator', name))
        # save last model state
        checkpoint = {
            'last_model': G,
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'name': name,
            'scheduler': scheduler
        }
        torch.save(checkpoint, filepath_out.format('checkpoint', name))
        if early_stopping.stop() and epoch >= FLAGS.min_epochs:
            break
        torch.cuda.empty_cache()
    end_time = time.time()

    # free memory
    del G
    torch.cuda.empty_cache()
    time_elapsed = end_time - start_time
    print(f'Ran {n_epochs} epochs in {round(time_elapsed, 1)} seconds')
    return results
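# Hedged sketch (illustration only, not from the original source): how the
# warmup gate in the training loop above interacts with
# CosineAnnealingWarmRestarts. The dummy optimizer and the values of _n_epochs
# and _warmup are assumptions chosen only to print the learning-rate schedule.
import torch


def _show_warmup_schedule(_n_epochs=60, _warmup=5):
    _opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    _sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        _opt, T_0=_n_epochs - _warmup, T_mult=1, eta_min=1e-5)
    for _epoch in range(_n_epochs):
        # the LR stays at its initial value until the warmup epochs have passed,
        # then follows a cosine decay with restarts
        if _epoch >= _warmup - 1:
            _sched.step()
        print(_epoch, _opt.param_groups[0]['lr'])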
def train(config=None, args=None, arch=None):
    graph = False
    modelfile = args.model
    trainloss = []
    validloss = []
    learningrate = []

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True
    # torch.autograd.set_detect_anomaly(True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    print("Using training file:", config.trainfile)

    model = network(config=config, arch=arch, seqlen=config.seqlen).to(device)
    print("Model parameters:", sum(p.numel() for p in model.parameters()))

    if modelfile is not None:
        print("Loading pretrained model:", modelfile)
        model.load_state_dict(torch.load(modelfile))

    if args.verbose:
        print("Optimizer:", config.optimizer, "lr:", config.lr,
              "weightdecay:", config.weightdecay)
        print("Scheduler:", config.scheduler, "patience:",
              config.scheduler_patience, "factor:", config.scheduler_factor,
              "threshold:", config.scheduler_threshold, "minlr:",
              config.scheduler_minlr, "reduce:", config.scheduler_reduce)

    if config.optimizer.lower() == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=config.lr,
                                      weight_decay=config.weightdecay)
    elif config.optimizer.lower() == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    elif config.optimizer.lower() == "ranger":
        from pytorch_ranger import Ranger
        optimizer = Ranger(model.parameters(),
                           lr=config.lr,
                           weight_decay=config.weightdecay)

    if args.verbose:
        print(model)

    # dummy forward pass to determine the encoder output length for CTC
    model.eval()
    with torch.no_grad():
        fakedata = torch.rand((1, 1, config.seqlen))
        fakeout = model.forward(fakedata.to(device))
        elen = fakeout.shape[0]

    data = dataloader(recfile=config.trainfile,
                      seq_len=config.seqlen,
                      elen=elen)
    data_loader = DataLoader(dataset=data,
                             batch_size=config.batchsize,
                             shuffle=True,
                             num_workers=args.workers,
                             pin_memory=True)

    if config.scheduler == "reducelronplateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=config.scheduler_patience,
            factor=config.scheduler_factor,
            verbose=args.verbose,
            threshold=config.scheduler_threshold,
            min_lr=config.scheduler_minlr)

    count = 0
    last = None

    if config.amp:
        print("Using amp")
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

    if args.statedict:
        print("Loading pretrained model:", args.statedict)
        checkpoint = torch.load(args.statedict)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    # from Bonito, but the weight for the blank symbol changed to 0.1 from 0.4
    if args.labelsmoothing:
        C = len(config.vocab)
        smoothweights = torch.cat(
            [torch.tensor([0.1]), (0.1 / (C - 1)) * torch.ones(C - 1)]).to(device)

    if not os.path.isdir(args.savedir):
        os.mkdir(args.savedir)
    shutil.rmtree(args.savedir + "/" + config.name, True)

    if args.tensorboard:
        writer = SummaryWriter(args.savedir + "/" + config.name)
        if not graph:
            a, b, c, d = next(iter(data_loader))
            a = torch.unsqueeze(a, 1)
            writer.add_graph(model, a.to(device))

    # criterion = nn.CTCLoss(reduction="mean", zero_infinity=True)  # test

    for epoch in range(config.epochs):
        model.train()
        totalloss = 0
        loopcount = 0
        learningrate.append(optimizer.param_groups[0]['lr'])
        if args.verbose:
            print("Learning rate:", learningrate[-1])
        for i, (event, event_len, label, label_len) in enumerate(data_loader):
            event = torch.unsqueeze(event, 1)
            if event.shape[0] < config.batchsize:
                continue
            label = label[:, :max(label_len)]
            event = event.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)
            event_len = event_len.to(device, non_blocking=True)
            label_len = label_len.to(device, non_blocking=True)
            optimizer.zero_grad()
            out = model.forward(event)
            if args.labelsmoothing:
                losses = ont.ctc_label_smoothing_loss(out, label, label_len,
                                                      smoothweights)
                loss = losses["ctc_loss"]
            else:
                loss = torch.nn.functional.ctc_loss(
                    out,
                    label,
                    event_len,
                    label_len,
                    reduction="mean",
                    blank=config.vocab.index('<PAD>'),
                    zero_infinity=True)
            # loss = criterion(out, label, event_len, label_len)
            totalloss += loss.cpu().detach().numpy()
            print("Loss", loss.data, "epoch:", epoch, count,
                  optimizer.param_groups[0]['lr'])
            if config.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if args.labelsmoothing:
                    losses["loss"].backward()
                else:
                    loss.backward()
            if config.gradclip:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), config.gradclip)
            optimizer.step()
            loopcount += 1
            count += 1
            if loopcount >= config.train_loopcount:
                break

        if args.tensorboard:
            tensorboard_writer_values(writer, model)
        if args.verbose:
            print("Train epoch loss", totalloss / loopcount)

        vl = validate(model,
                      device,
                      config=config,
                      args=args,
                      epoch=epoch,
                      elen=elen)

        if config.scheduler == "reducelronplateau":
            scheduler.step(vl)
        elif config.scheduler == "decay":
            if (epoch > 0) and (epoch % config.scheduler_reduce == 0):
                optimizer.param_groups[0]['lr'] *= config.scheduler_factor
                if optimizer.param_groups[0]['lr'] < config.scheduler_minlr:
                    optimizer.param_groups[0]['lr'] = config.scheduler_minlr

        trainloss.append(float(totalloss / loopcount))
        validloss.append(vl)
        if args.tensorboard:
            tensorboard_writer_value(writer, "training loss",
                                     float(totalloss / loopcount))
            tensorboard_writer_value(writer, "validation loss", vl)

        with open(args.savedir + "/" + config.name + "-stats.pickle", "wb") as f:
            pickle.dump([trainloss, validloss], f)
            pickle.dump(config.orig, f)
            pickle.dump(learningrate, f)
        torch.save(
            get_config(model, config.orig),
            args.savedir + "/" + config.name + "-epoch" + str(epoch) + ".torch")
        torch.save(get_checkpoint(epoch, model, optimizer, scheduler),
                   args.savedir + "/" + config.name + "-ext.torch")

    if args.verbose:
        print("Train losses:", trainloss)
        print("Valid losses:", validloss)
        print("Learning rate:", learningrate)
    print("Model", config.name, "done.")
    return trainloss, validloss
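# Hedged sketch (illustration only, not from the original source): the
# label-smoothing weight vector built in train() above, shown for a
# hypothetical 5-symbol vocabulary ('<PAD>' plus four bases). The blank/'<PAD>'
# class gets weight 0.1 and another 0.1 is split evenly over the other symbols.
import torch

_C = 5
_smoothweights = torch.cat(
    [torch.tensor([0.1]), (0.1 / (_C - 1)) * torch.ones(_C - 1)])
print(_smoothweights)  # tensor([0.1000, 0.0250, 0.0250, 0.0250, 0.0250])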
g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
e_reg_ratio = args.e_reg_every / (args.e_reg_every + 1)

g_optim = optim.Adam(
    generator.parameters(),
    lr=args.lr * g_reg_ratio,
    betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
)
d_optim = optim.Adam(
    discriminator.parameters(),
    lr=args.lr * d_reg_ratio,
    betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
)
e_optim = Ranger(encoder.parameters())

if args.ckpt is not None:
    print("load model:", args.ckpt)
    ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
    try:
        ckpt_name = os.path.basename(args.ckpt)
        args.start_iter = int(os.path.splitext(ckpt_name)[0])
    except ValueError:
        pass
    generator.load_state_dict(ckpt["g"])
    discriminator.load_state_dict(ckpt["d"])
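# Hedged note (illustration only, not from the original source): with lazy
# regularization the R1 / path-length penalties run only every N steps, so the
# Adam lr and betas above are rescaled by N / (N + 1) to keep the effective
# hyperparameters roughly comparable. The values 4 and 16 are assumptions here,
# chosen because they are common StyleGAN2 defaults.
_g_reg_every, _d_reg_every = 4, 16
_g_ratio = _g_reg_every / (_g_reg_every + 1)    # 0.8
_d_ratio = _d_reg_every / (_d_reg_every + 1)    # 16/17 ~= 0.941
print(_g_ratio, 0.99 ** _g_ratio)               # lr multiplier and scaled beta2, generator
print(_d_ratio, 0.99 ** _d_ratio)               # lr multiplier and scaled beta2, discriminator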
    'n_slots': 8,
    'discretize': 0,
    'span_dropout': None,  # span_dropout,
}
model_config.update({'char_kwargs': deepcopy(model_config)})
model_config['char_kwargs']['i'] = model_config['char_i']
model_config['char_kwargs']['o'] = model_config['char_i']
model_config['char_kwargs']['wd'] = None
model_config['char_kwargs']['discretize'] = 0
model_config['char_kwargs']['char_level'] = True
model_config['char_kwargs']['n_heads'] = 1

P = Parser(**model_config)
opt = Ranger(P.parameters())
mse = nn.MSELoss()
ce = nn.CrossEntropyLoss()

data = pd.DataFrame({
    'text': list(
        filter(
            lambda x: (lambda y: y != [] and len(y[0]) <= limit)(preprocessor(x)),
            chain(*df['text'].apply(nltk.sent_tokenize).tolist())))
})
data = data.sample(len(data))
n_sentences = len(data)
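# Hedged sketch (illustration only, not from the original source): an equivalent,
# easier-to-read form of the nested-lambda sentence filter above, using a named
# helper. `preprocessor`, `limit`, and `df` are the same objects the original
# code assumes to be in scope; the rewritten DataFrame line stays commented out.
def _keep_sentence(sentence):
    tokens = preprocessor(sentence)
    # keep sentences whose preprocessed output is non-empty and whose first
    # token sequence is no longer than `limit`
    return tokens != [] and len(tokens[0]) <= limit

# data = pd.DataFrame({
#     'text': [s for s in chain(*df['text'].apply(nltk.sent_tokenize).tolist())
#              if _keep_sentence(s)]
# })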