Example #1
def make_optimizer(model, cfg):

    assert cfg.SOLVER.OPTIMIZER in [
        'Adam', 'SGD', 'Ranger', 'RangerQH', 'RangerALR'
    ], 'Optimizer name not recognized!'

    if cfg.SOLVER.OPTIMIZER == 'Adam':
        return torch.optim.Adam(model.parameters(),
                                lr=cfg.SOLVER.LR,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY,
                                betas=cfg.SOLVER.BETAS,
                                amsgrad=cfg.SOLVER.AMSGRAD)
    elif cfg.SOLVER.OPTIMIZER == 'SGD':
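        # NOTE: torch.optim.SGD only accepts nesterov=True together with a
        # positive momentum (and zero dampening); no momentum is passed here.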
        return torch.optim.SGD(model.parameters(),
                               lr=cfg.SOLVER.LR,
                               weight_decay=cfg.SOLVER.WEIGHT_DECAY,
                               nesterov=cfg.SOLVER.NESTEROV)

    elif cfg.SOLVER.OPTIMIZER == 'Ranger':
        return Ranger(model.parameters(),
                      lr=cfg.SOLVER.LR,
                      weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    elif cfg.SOLVER.OPTIMIZER == 'RangerQH':
        return RangerQH(model.parameters(),
                        lr=cfg.SOLVER.LR,
                        weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    elif cfg.SOLVER.OPTIMIZER == 'RangerALR':
        return RangerVA(model.parameters(),
                        lr=cfg.SOLVER.LR,
                        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
                        amsgrad=cfg.SOLVER.AMSGRAD)
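A minimal usage sketch for the helper above (hypothetical, not part of the original repository), assuming the config is a yacs-style CfgNode and that the Ranger classes are importable, e.g. from pytorch_ranger:

import torch
from yacs.config import CfgNode as CN

cfg = CN()
cfg.SOLVER = CN()
cfg.SOLVER.OPTIMIZER = 'Ranger'      # any of the names accepted above
cfg.SOLVER.LR = 1e-3
cfg.SOLVER.WEIGHT_DECAY = 1e-4

model = torch.nn.Linear(16, 4)       # stand-in for a real model
optimizer = make_optimizer(model, cfg)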
Example #2
def train(params, n_epochs, verbose=True):
    # init interpolation model
    timestamp = int(time.time())
    formatted_params = '_'.join(f'{k}={v}' for k, v in params.items())

    torch.manual_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    if FLAGS.filename is None:
        G = SepConvNetExtended(kl_size=params['kl_size'],
                               kq_size=params['kq_size'],
                               kl_d_size=params['kl_d_size'],
                               kl_d_scale=params['kl_d_scale'],
                               kq_d_scale=params['kq_d_scale'],
                               kq_d_size=params['kq_d_size'],
                               input_frames=params['input_size'])

        if params['pretrain'] in [1, 2]:
            print('LOADING L1')
            G.load_weights('l1')
        name = f'{timestamp}_seed_{FLAGS.seed}_{formatted_params}'
        G = torch.nn.DataParallel(G).cuda()

        # optimizer = torch.optim.Adamax(G.parameters(), lr=params['lr'], betas=(.9, .999))
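        # two parameter groups: parameters whose name contains 'moduleConv'
        # train with lr2, all remaining parameters with lr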
        if params['optimizer'] == 'ranger':
            optimizer = Ranger(
                [{'params': [p for l, p in G.named_parameters()
                             if 'moduleConv' not in l]},
                 {'params': [p for l, p in G.named_parameters()
                             if 'moduleConv' in l],
                  'lr': params['lr2']}],
                lr=params['lr'],
                betas=(.95, .999))

        elif params['optimizer'] == 'adamax':
            optimizer = torch.optim.Adamax(
                [{'params': [p for l, p in G.named_parameters()
                             if 'moduleConv' not in l]},
                 {'params': [p for l, p in G.named_parameters()
                             if 'moduleConv' in l],
                  'lr': params['lr2']}],
                lr=params['lr'],
                betas=(.9, .999))

        else:
            raise NotImplementedError()

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=60 - FLAGS.warmup, T_mult=1, eta_min=1e-5)
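        # scheduler.step() is only called once epoch >= FLAGS.warmup - 1 (see the
        # training loop below), so the warm-restart period excludes the warmup
        # epochs; note the hard-coded 60 here versus n_epochs in the resume branch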

        start_epoch = 0

    else:
        checkpoint = torch.load(FLAGS.filename)
        G = checkpoint['last_model'].cuda()
        start_epoch = checkpoint['epoch'] + 1
        name = checkpoint['name']

        optimizer = torch.optim.Adamax(
            [{'params': [p for l, p in G.named_parameters()
                         if 'moduleConv' not in l]},
             {'params': [p for l, p in G.named_parameters()
                         if 'moduleConv' in l],
              'lr': params['lr2']}],
            lr=params['lr'],
            betas=(.9, .999))

        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=n_epochs - FLAGS.warmup, eta_min=1e-5, last_epoch=-1)

        for _ in range(start_epoch - FLAGS.warmup + 1):
            scheduler.step()

    print('SETTINGS:')
    print(params)
    print('NAME:')
    print(name)
    sys.stdout.flush()

    # loss_network = losses.LossNetwork(layers=[9,16,26]).cuda() #9, 16, 26
    # Perc_loss = losses.PerceptualLoss(loss_network, include_input=True)

    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    quadratic = params['input_size'] == 4

    L1_loss = torch.nn.L1Loss().cuda()
    # Flow_loss = losses.FlowLoss(quadratic=quadratic).cuda()

    ds_train_lmd = dataloader.large_motion_dataset(quadratic=quadratic,
                                                   cropped=True,
                                                   fold='train',
                                                   min_flow=6)
    ds_valid_lmd = dataloader.large_motion_dataset(quadratic=quadratic,
                                                   cropped=True,
                                                   fold='valid')

    ds_vimeo_train = dataloader.vimeo90k_dataset(fold='train',
                                                 quadratic=quadratic)
    ds_vimeo_test = dataloader.vimeo90k_dataset(fold='test',
                                                quadratic=quadratic)

    # train, test_lmd = dataloader.split_data(ds_lmd, [.9, .1])
    train_vimeo, valid_vimeo = dataloader.split_data(ds_vimeo_train, [.9, .1])

    # torch.manual_seed(FLAGS.seed)
    # np.random.seed(FLAGS.seed)
    # random.seed(FLAGS.seed)

    train_settings = {
        'flip_probs': FLAGS.flip_probs,
        'normalize': True,
        'crop_size': (FLAGS.crop_size, FLAGS.crop_size),
        'jitter_prob': FLAGS.jitter_prob,
        'random_rescale_prob': FLAGS.random_rescale_prob
        # 'rescale_distr':(.8, 1.2),
    }

    valid_settings = {
        'flip_probs': 0,
        'random_rescale_prob': 0,
        'random_crop': False,
        'normalize': True
    }

    train_lmd = dataloader.TransformedDataset(ds_train_lmd, **train_settings)
    valid_lmd = dataloader.TransformedDataset(ds_valid_lmd, **valid_settings)

    train_vimeo = dataloader.TransformedDataset(train_vimeo, **train_settings)
    valid_vimeo = dataloader.TransformedDataset(valid_vimeo, **valid_settings)
    test_vimeo = dataloader.TransformedDataset(ds_vimeo_test, **valid_settings)

    train_data = torch.utils.data.ConcatDataset([train_lmd, train_vimeo])

    # displacement
    df = pd.read_csv('hardinstancesinfo/vimeo90k_test_flow.csv')
    test_disp = torch.utils.data.Subset(
        ds_vimeo_test,
        indices=df[df.mean_manh_flow >= df.quantile(.9).mean_manh_flow].index.tolist())
    test_disp = dataloader.TransformedDataset(test_disp, **valid_settings)
    test_disp = torch.utils.data.DataLoader(test_disp,
                                            batch_size=4,
                                            pin_memory=True)

    # nonlinearity
    df = pd.read_csv('hardinstancesinfo/Vimeo90K_test.csv')
    test_nonlin = torch.utils.data.Subset(
        ds_vimeo_test,
        indices=df[df.non_linearity >= df.quantile(.9).non_linearity].index.tolist())
    test_nonlin = dataloader.TransformedDataset(test_nonlin, **valid_settings)
    test_nonlin = torch.utils.data.DataLoader(test_nonlin,
                                              batch_size=4,
                                              pin_memory=True)

    # create weights for train sampler
    df_vim = pd.read_csv('hardinstancesinfo/vimeo90k_train_flow.csv')
    weights_vim = df_vim[df_vim.index.isin(
        train_vimeo.dataset.indices)].mean_manh_flow.tolist()
    weights_lmd = ds_train_lmd.weights
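    # each epoch draws FLAGS.num_train_samples clips without replacement,
    # weighted by the per-clip weights assembled above (mean Manhattan flow
    # for the Vimeo90K subset), so large-motion examples are sampled more often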
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
        weights_lmd + weights_vim, FLAGS.num_train_samples, replacement=False)

    train_dl = torch.utils.data.DataLoader(train_data,
                                           batch_size=FLAGS.batch_size,
                                           pin_memory=True,
                                           shuffle=False,
                                           sampler=train_sampler,
                                           num_workers=FLAGS.num_workers)
    valid_dl_vim = torch.utils.data.DataLoader(valid_vimeo,
                                               batch_size=4,
                                               pin_memory=True,
                                               num_workers=FLAGS.num_workers)
    valid_dl_lmd = torch.utils.data.DataLoader(valid_lmd,
                                               batch_size=4,
                                               pin_memory=True,
                                               num_workers=FLAGS.num_workers)
    test_dl_vim = torch.utils.data.DataLoader(test_vimeo,
                                              batch_size=4,
                                              pin_memory=True,
                                              num_workers=FLAGS.num_workers)

    # metrics
    writer = SummaryWriter(f'runs/final_exp/full_run_losses/{name}')

    results = ResultStore(writer=writer,
                          metrics=['psnr', 'ssim', 'ie', 'L1_loss', 'lf'],
                          folds=FOLDS)

    early_stopping_metric = 'L1_loss'
    early_stopping = EarlyStopping(results,
                                   patience=FLAGS.patience,
                                   metric=early_stopping_metric,
                                   fold='valid_vimeo')

    loss_network = losses.LossNetwork(layers=[26]).cuda()  #9, 16, 26
    Perc_loss = losses.PerceptualLoss(loss_network).cuda()

    def do_epoch(dataloader, fold, epoch, train=False):
        assert fold in FOLDS

        if verbose:
            pb = tqdm(desc=f'{fold} {epoch+1}/{n_epochs}',
                      total=len(dataloader),
                      leave=True,
                      position=0)

        for i, (X, y) in enumerate(dataloader):
            X = X.cuda()
            y = y.cuda()

            y_hat = G(X)

            l1_loss = L1_loss(y_hat, y)
            feature_loss = Perc_loss(y_hat, y)

            lf_loss = l1_loss + feature_loss

            if train:
                optimizer.zero_grad()
                lf_loss.backward()
                optimizer.step()

            # compute metrics
            y_hat = (y_hat * 255).clamp(0, 255)
            y = (y * 255).clamp(0, 255)

            psnr = metrics.psnr(y_hat, y)
            ssim = metrics.ssim(y_hat, y)
            ie = metrics.interpolation_error(y_hat, y)

            results.store(
                fold, epoch, {
                    'L1_loss': l1_loss.item(),
                    'psnr': psnr,
                    'ssim': ssim,
                    'ie': ie,
                    'lf': lf_loss.item()
                })

            if verbose: pb.update()

        # update tensorboard
        results.write_tensorboard(fold, epoch)
        sys.stdout.flush()

    start_time = time.time()
    for epoch in range(start_epoch, n_epochs):

        G.train()
        do_epoch(train_dl, 'train_fold', epoch, train=True)

        if epoch >= FLAGS.warmup - 1:
            scheduler.step()

        G.eval()
        with torch.no_grad():
            do_epoch(valid_dl_vim, 'valid_vimeo', epoch)
            do_epoch(valid_dl_lmd, 'valid_lmd', epoch)

        if ((early_stopping.stop() and epoch >= FLAGS.min_epochs)
                or epoch % FLAGS.test_every == 0 or epoch + 1 == n_epochs):
            with torch.no_grad():
                do_epoch(test_disp, 'test_disp', epoch)
                do_epoch(test_nonlin, 'test_nonlin', epoch)

                do_epoch(test_dl_vim, 'test_vimeo', epoch)

            visual_evaluation(model=G,
                              quadratic=params['input_size'] == 4,
                              writer=writer,
                              epoch=epoch)

            visual_evaluation_vimeo(model=G,
                                    quadratic=params['input_size'] == 4,
                                    writer=writer,
                                    epoch=epoch)

        # save model if new best
        filepath_out = os.path.join(MODEL_FOLDER, '{0}_{1}')
        if early_stopping.new_best():
            torch.save(G, filepath_out.format('generator', name))

        # save last model state
        checkpoint = {
            'last_model': G,
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'name': name,
            'scheduler': scheduler
        }
        torch.save(checkpoint, filepath_out.format('checkpoint', name))

        if early_stopping.stop() and epoch >= FLAGS.min_epochs:
            break

        torch.cuda.empty_cache()

    end_time = time.time()
    # free memory
    del G
    torch.cuda.empty_cache()
    time_elapsed = end_time - start_time
    print(f'Ran {n_epochs} epochs in {round(time_elapsed, 1)} seconds')

    return results
Example #3
def train(config=None, args=None, arch=None):
    graph = False
    modelfile = args.model
    trainloss = []
    validloss = []
    learningrate = []

    torch.backends.cudnn.benchmark = True
    #torch.backends.cudnn.deterministic = True
    #torch.autograd.set_detect_anomaly(True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    print("Using training file:", config.trainfile)

    model = network(config=config, arch=arch, seqlen=config.seqlen).to(device)

    print("Model parameters:", sum(p.numel() for p in model.parameters()))
    if modelfile is not None:
        print("Loading pretrained model:", modelfile)
        model.load_state_dict(torch.load(modelfile))

    if args.verbose:
        print("Optimizer:", config.optimizer, "lr:", config.lr, "weightdecay",
              config.weightdecay)
        print("Scheduler:", config.scheduler, "patience:",
              config.scheduler_patience, "factor:", config.scheduler_factor,
              "threshold", config.scheduler_threshold, "minlr:",
              config.scheduler_minlr, "reduce:", config.scheduler_reduce)

    if config.optimizer.lower() == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=config.lr,
                                      weight_decay=config.weightdecay)
    elif config.optimizer.lower() == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    elif config.optimizer.lower() == "ranger":
        from pytorch_ranger import Ranger
        optimizer = Ranger(model.parameters(),
                           lr=config.lr,
                           weight_decay=config.weightdecay)

    if args.verbose: print(model)

    model.eval()
    with torch.no_grad():
        fakedata = torch.rand((1, 1, config.seqlen))
        fakeout = model.forward(fakedata.to(device))
        elen = fakeout.shape[0]

    data = dataloader(recfile=config.trainfile,
                      seq_len=config.seqlen,
                      elen=elen)
    data_loader = DataLoader(dataset=data,
                             batch_size=config.batchsize,
                             shuffle=True,
                             num_workers=args.workers,
                             pin_memory=True)

    if config.scheduler == "reducelronplateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=config.scheduler_patience,
            factor=config.scheduler_factor,
            verbose=args.verbose,
            threshold=config.scheduler_threshold,
            min_lr=config.scheduler_minlr)

    count = 0
    last = None

    if config.amp:
        print("Using amp")
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

    if args.statedict:
        print("Loading pretrained model:", args.statedict)
        checkpoint = torch.load(args.statedict)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint["scheduler"])

    # from Bonito but weighting for blank changed to 0.1 from 0.4
    if args.labelsmoothing:
        C = len(config.vocab)
        smoothweights = torch.cat(
            [torch.tensor([0.1]),
             (0.1 / (C - 1)) * torch.ones(C - 1)]).to(device)

    if not os.path.isdir(args.savedir):
        os.mkdir(args.savedir)

    shutil.rmtree(args.savedir + "/" + config.name, True)
    if args.tensorboard:
        writer = SummaryWriter(args.savedir + "/" + config.name)
        if not graph:
            a, b, c, d = next(iter(data_loader))
            a = torch.unsqueeze(a, 1)
            writer.add_graph(model, a.to(device))

    #criterion = nn.CTCLoss(reduction="mean", zero_infinity=True) # test

    for epoch in range(config.epochs):
        model.train()
        totalloss = 0
        loopcount = 0
        learningrate.append(optimizer.param_groups[0]['lr'])
        if args.verbose: print("Learning rate:", learningrate[-1])

        for i, (event, event_len, label, label_len) in enumerate(data_loader):
            event = torch.unsqueeze(event, 1)
            if event.shape[0] < config.batchsize: continue

            label = label[:, :max(label_len)]
            event = event.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)
            event_len = event_len.to(device, non_blocking=True)
            label_len = label_len.to(device, non_blocking=True)

            optimizer.zero_grad()

            out = model.forward(event)

            if args.labelsmoothing:
                losses = ont.ctc_label_smoothing_loss(out, label, label_len,
                                                      smoothweights)
                loss = losses["ctc_loss"]
            else:
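                # F.ctc_loss expects log-probabilities of shape (T, N, C); the
                # model output is assumed to already be in that layout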
                loss = torch.nn.functional.ctc_loss(
                    out,
                    label,
                    event_len,
                    label_len,
                    reduction="mean",
                    blank=config.vocab.index('<PAD>'),
                    zero_infinity=True)
                #loss = criterion(out, label, event_len, label_len)

            totalloss += loss.item()
            print("Loss", loss.item(), "epoch:", epoch, count,
                  optimizer.param_groups[0]['lr'])

            if config.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if args.labelsmoothing:
                    losses["loss"].backward()
                else:
                    loss.backward()

            if config.gradclip:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), config.gradclip)

            optimizer.step()
            loopcount += 1
            count += 1
            if loopcount >= config.train_loopcount: break

        if args.tensorboard: tensorboard_writer_values(writer, model)
        if args.verbose: print("Train epoch loss", totalloss / loopcount)

        vl = validate(model,
                      device,
                      config=config,
                      args=args,
                      epoch=epoch,
                      elen=elen)

        if config.scheduler == "reducelronplateau":
            scheduler.step(vl)
        elif config.scheduler == "decay":
            if (epoch > 0) and (epoch % config.scheduler_reduce == 0):
                optimizer.param_groups[0]['lr'] *= config.scheduler_factor
                if optimizer.param_groups[0]['lr'] < config.scheduler_minlr:
                    optimizer.param_groups[0]['lr'] = config.scheduler_minlr

        trainloss.append(float(totalloss / loopcount))
        validloss.append(vl)

        if args.tensorboard:
            tensorboard_writer_value(writer, "training loss",
                                     float(totalloss / loopcount))
            tensorboard_writer_value(writer, "validation loss", vl)

        with open(args.savedir + "/" + config.name + "-stats.pickle", "wb") as f:
            pickle.dump([trainloss, validloss], f)
            pickle.dump(config.orig, f)
            pickle.dump(learningrate, f)

        torch.save(
            get_config(model, config.orig), args.savedir + "/" + config.name +
            "-epoch" + str(epoch) + ".torch")
        torch.save(get_checkpoint(epoch, model, optimizer, scheduler),
                   args.savedir + "/" + config.name + "-ext.torch")

        if args.verbose:
            print("Train losses:", trainloss)
            print("Valid losses:", validloss)
            print("Learning rate:", learningrate)

    print("Model", config.name, "done.")
    return trainloss, validloss
Example #4
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
    e_reg_ratio = args.e_reg_every / (args.e_reg_every + 1)
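    # lazy regularization: the generator / discriminator / encoder regularizers
    # only run every *_reg_every iterations, so lr and the Adam betas are
    # rescaled by N / (N + 1) to keep the optimizer statistics consistent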

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )
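    # the encoder is trained with Ranger using its default hyper-parameters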
    e_optim = Ranger(encoder.parameters())

    if args.ckpt is not None:
        print("load model:", args.ckpt)

        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        try:
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])

        except ValueError:
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
Example #5
    'n_slots': 8,
    'discretize': 0,
    'span_dropout': None,  #span_dropout,
}

model_config.update({'char_kwargs': deepcopy(model_config)})
model_config['char_kwargs']['i'] = model_config['char_i']
model_config['char_kwargs']['o'] = model_config['char_i']
model_config['char_kwargs']['wd'] = None
model_config['char_kwargs']['discretize'] = 0
model_config['char_kwargs']['char_level'] = True
model_config['char_kwargs']['n_heads'] = 1

P = Parser(**model_config)

opt = Ranger(P.parameters())

mse = nn.MSELoss()
ce = nn.CrossEntropyLoss()

data = pd.DataFrame({
    'text': [
        s for s in chain(*df['text'].apply(nltk.sent_tokenize).tolist())
        if (lambda y: y != [] and len(y[0]) <= limit)(preprocessor(s))
    ]
})
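# shuffle the rows (sample the whole frame without replacement)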
data = data.sample(len(data))

n_sentences = len(data)