Example no. 1
def get_yogi(parameters, configs):
    print('Prediction using YOGI optimizer')
    if "betas" in configs.keys():
        print("Set betas to values from the config file: ")
        print(*configs["betas"], sep=", ")
        return optim.Yogi(parameters,
                          lr=configs['learning_rate'],
                          betas=configs["betas"])

    else:
        return optim.Yogi(parameters, lr=configs['learning_rate'])
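
A minimal usage sketch for the factory above (assumptions: optim refers to the torch_optimizer package, and the model and config values below are purely illustrative):

import torch.nn as nn
import torch_optimizer as optim

model = nn.Linear(16, 4)  # toy model, for illustration only
configs = {"learning_rate": 1e-3, "betas": (0.9, 0.999)}
optimizer = get_yogi(model.parameters(), configs)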
Example no. 2
def get_optimizer(optimizer: str, model, optimizer_args):
    if optimizer == "sgd":
        return torch.optim.SGD(model.parameters(), **optimizer_args)
    elif optimizer == "adam":
        return torch.optim.Adam(model.parameters(), **optimizer_args)
    elif optimizer == "yogi":
        return optim.Yogi(model.parameters(), **optimizer_args)
    elif optimizer == "shampoo":
        return optim.Shampoo(model.parameters(), **optimizer_args)
    elif optimizer == "swats":
        return optim.SWATS(model.parameters(), **optimizer_args)
    elif optimizer == "sgdw":
        return optim.SGDW(model.parameters(), **optimizer_args)
    elif optimizer == "sgdp":
        return optim.SGDP(model.parameters(), **optimizer_args)
    elif optimizer == "rangerva":
        return optim.RangerVA(model.parameters(), **optimizer_args)
    elif optimizer == "rangerqh":
        return optim.RangerQH(model.parameters(), **optimizer_args)
    elif optimizer == "ranger":
        return optim.Ranger(model.parameters(), **optimizer_args)
    elif optimizer == "radam":
        return optim.RAdam(model.parameters(), **optimizer_args)
    elif optimizer == "qhm":
        return optim.QHM(model.parameters(), **optimizer_args)
    elif optimizer == "qhadam":
        return optim.QHAdam(model.parameters(), **optimizer_args)
    elif optimizer == "pid":
        return optim.PID(model.parameters(), **optimizer_args)
    elif optimizer == "novograd":
        return optim.NovoGrad(model.parameters(), **optimizer_args)
    elif optimizer == "lamb":
        return optim.Lamb(model.parameters(), **optimizer_args)
    elif optimizer == "diffgrad":
        return optim.DiffGrad(model.parameters(), **optimizer_args)
    elif optimizer == "apollo":
        return optim.Apollo(model.parameters(), **optimizer_args)
    elif optimizer == "aggmo":
        return optim.AggMo(model.parameters(), **optimizer_args)
    elif optimizer == "adamp":
        return optim.AdamP(model.parameters(), **optimizer_args)
    elif optimizer == "adafactor":
        return optim.Adafactor(model.parameters(), **optimizer_args)
    elif optimizer == "adamod":
        return optim.AdaMod(model.parameters(), **optimizer_args)
    elif optimizer == "adabound":
        return optim.AdaBound(model.parameters(), **optimizer_args)
    elif optimizer == "adabelief":
        return optim.AdaBelief(model.parameters(), **optimizer_args)
    elif optimizer == "accsgd":
        return optim.AccSGD(model.parameters(), **optimizer_args)
    elif optimizer == "a2graduni":
        return optim.A2GradUni(model.parameters(), **optimizer_args)
    elif optimizer == "a2gradinc":
        return optim.A2GradInc(model.parameters(), **optimizer_args)
    elif optimizer == "a2gradexp":
        return optim.A2GradExp(model.parameters(), **optimizer_args)
    else:
        raise Exception(f"Optimizer '{optimizer}' does not exist!")
Example no. 3
def create_optimizer(arg, parameters, create_scheduler=False, discrim=False):
    lr = arg.lr_discrim if discrim else arg.lr
    weight_decay = arg.weight_decay_discrim if discrim else arg.weight_decay
    if arg.optimizer == 'Lamb':
        optimizer = optim.Lamb(parameters,
                               lr=lr,
                               weight_decay=weight_decay,
                               betas=(0.5, 0.999))
    elif arg.optimizer == 'AdaBound':
        optimizer = optim.AdaBound(parameters,
                                   lr=lr,
                                   weight_decay=weight_decay,
                                   betas=(0.5, 0.999))
    elif arg.optimizer == 'Yogi':
        optimizer = optim.Yogi(parameters,
                               lr=lr,
                               weight_decay=weight_decay,
                               betas=(0.5, 0.999))
    elif arg.optimizer == 'DiffGrad':
        optimizer = optim.DiffGrad(parameters,
                                   lr=lr,
                                   weight_decay=weight_decay,
                                   betas=(0.5, 0.999))
    elif arg.optimizer == 'Adam':
        optimizer = torch.optim.Adam(parameters,
                                     lr=lr,
                                     weight_decay=weight_decay,
                                     betas=(0.5, 0.999))
    else:
        optimizer = torch.optim.SGD(parameters,
                                    lr=lr,
                                    momentum=arg.momentum,
                                    weight_decay=weight_decay)

    if create_scheduler:
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   factor=0.2,
                                                   patience=4,
                                                   threshold=1e-2,
                                                   verbose=True)
    else:
        scheduler = None

    return optimizer, scheduler
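
A minimal call sketch for create_optimizer above (assumptions: the arguments live in an argparse-style namespace, lr_scheduler is torch.optim.lr_scheduler, and the field values below are illustrative):

from argparse import Namespace
import torch.nn as nn

arg = Namespace(optimizer='Yogi', lr=1e-3, lr_discrim=1e-4,
                weight_decay=1e-4, weight_decay_discrim=0.0, momentum=0.9)
model = nn.Linear(16, 4)  # toy model, for illustration only
optimizer, scheduler = create_optimizer(arg, model.parameters(),
                                        create_scheduler=True)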
Example no. 4
def create_optimizer(optimizer_config: Dict[str, Any], model: nn.Module):
    cp: Dict[str, Any] = copy(optimizer_config)
    n = cp.pop("name").lower()

    if n == "adam":
        optimizer: Optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    elif n == "adabound":
        optimizer = torch_optimizer.AdaBound(model.parameters(), **cp)
    elif n == "diffgrad":
        optimizer = torch_optimizer.DiffGrad(model.parameters(), **cp)
    elif n == "qhadam":
        optimizer = torch_optimizer.QHAdam(model.parameters(), **cp)
    elif n == "radam":
        optimizer = torch_optimizer.RAdam(model.parameters(), **cp)
    elif n == "yogi":
        optimizer = torch_optimizer.Yogi(model.parameters(), **cp)
    else:
        raise ValueError(f"Unknown optimizer: {n}")

    return optimizer
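
A minimal usage sketch for the config-driven factory above (every key except "name" is forwarded verbatim to the optimizer constructor; the values below are illustrative):

import torch.nn as nn

model = nn.Linear(16, 4)  # toy model, for illustration only
optimizer = create_optimizer({"name": "Yogi", "lr": 1e-3, "weight_decay": 1e-4}, model)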
Example no. 5
    def fit(self,
            x,
            y,
            n_epoch=100,
            batch_size=128,
            lr=0.001,
            weight_decay=0,
            dropout=0,
            verbose=False):
        self.dropout = dropout

        x_train, x_val, y_train, y_val, m_train, m_val = self.train_val_split(
            x, y)
        optimizer = optim.Yogi(self.parameters(),
                               lr=lr,
                               weight_decay=weight_decay)

        val_loss = []
        for epoch in range(n_epoch):
            mb = self.get_mini_batches(x_train,
                                       y_train,
                                       m_train,
                                       batch_size=batch_size)
            self.train()
            for x_mb, y_mb, m_mb in mb:
                loss = self.loss_batch(x_mb, y_mb, m_mb, optimizer=optimizer)

            self.eval()
            with torch.no_grad():
                loss = self.loss_batch(x_val, y_val, m_val, optimizer=None)
                val_loss.append(loss)

            min_loss_idx = val_loss.index(min(val_loss))
            if min_loss_idx == epoch:
                # state_dict() returns references to live tensors; clone so that
                # later training steps do not overwrite the saved best weights
                best_parameters = {k: v.detach().clone()
                                   for k, v in self.state_dict().items()}
                if verbose:
                    print(epoch, loss)
        self.load_state_dict(best_parameters, strict=True)
        return None
Example no. 6
def build_lookahead(*a, **kw):
    base = optim.Yogi(*a, **kw)
    return optim.Lookahead(base)
Example no. 7
def LookaheadYogi(*a, **kw):
    base = optim.Yogi(*a, **kw)
    return optim.Lookahead(base)
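
A minimal usage sketch for the two Lookahead-over-Yogi factories above (assumption: optim is the torch_optimizer package; the model and learning rate are illustrative). Lookahead keeps a slow copy of the weights and periodically interpolates it towards the fast weights produced by the inner Yogi steps:

import torch.nn as nn
import torch_optimizer as optim

model = nn.Linear(16, 4)  # toy model, for illustration only
optimizer = LookaheadYogi(model.parameters(), lr=1e-3)
# optimizer.step() / optimizer.zero_grad() are then used like any torch optimizer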
Example no. 8
def train(task_id,
          train_set,
          val_set,
          test_set,
          map_est_hypers=False,
          epochs=1,
          M=20,
          n_f=10,
          n_var_samples=3,
          batch_size=512,
          lr=1e-2,
          beta=1.0,
          eval_interval=10,
          patience=20,
          prev_params=None,
          logger=None,
          device=None):
    gp = create_class_gp(train_set,
                         M=M,
                         n_f=n_f,
                         n_var_samples=n_var_samples,
                         map_est_hypers=map_est_hypers,
                         prev_params=prev_params).to(device)

    stopper = EarlyStopper(patience=patience)

    # optim = torch.optim.Adam(gp.parameters(), lr=lr)
    optim = torch_optimizer.Yogi(gp.parameters(), lr=lr)

    N = len(train_set)
    loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    for e in tqdm(range(epochs)):
        for x, y in tqdm(loader, leave=False):
            optim.zero_grad()

            kl_hypers, kl_u, u_prev_reg, lik = gp.loss(x.to(device),
                                                       y.to(device))

            loss = beta * kl_hypers + kl_u - u_prev_reg + (N / x.size(0)) * lik
            loss.backward()

            optim.step()

        if (e + 1) % eval_interval == 0:
            train_acc = compute_accuracy(train_set, gp, device=device)
            val_acc = compute_accuracy(val_set, gp, device=device)
            test_acc = compute_accuracy(test_set, gp, device=device)

            loss_summary = {
                f'task{task_id}/loss/kl_hypers': kl_hypers.detach().item(),
                f'task{task_id}/loss/kl_u': kl_u.detach().item(),
                f'task{task_id}/loss/lik': lik.detach().item()
            }

            acc_summary = {
                f'task{task_id}/train/acc': train_acc,
                f'task{task_id}/val/acc': val_acc,
                f'task{task_id}/test/acc': test_acc,
            }

            if logger is not None:
                for k, v in (dict(**loss_summary, **acc_summary)).items():
                    logger.add_scalar(k, v, global_step=e + 1)

            stopper(
                val_acc,
                dict(state_dict=gp.state_dict(),
                     acc_summary=acc_summary,
                     step=e + 1))
            if stopper.is_done():
                break

    info = stopper.info()
    if logger is not None:
        for k, v in info.get('acc_summary').items():
            logger.add_scalar(f'{k}_best', v, global_step=info.get('step'))

        with open(f'{logger.log_dir}/ckpt{task_id}.pt', 'wb') as f:
            torch.save(info.get('state_dict'), f)
        wandb.save(f'{logger.log_dir}/ckpt{task_id}.pt')

    return info.get('state_dict')
Example no. 9
def build_optimizer(cfg, model):
    name_optimizer = cfg.optimizer.type
    optimizer = None

    if name_optimizer == 'A2GradExp':
        optimizer = optim.A2GradExp(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'A2GradInc':
        optimizer = optim.A2GradInc(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'A2GradUni':
        optimizer = optim.A2GradUni(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AccSGD':
        optimizer = optim.AccSGD(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AdaBelief':
        optimizer = optim.AdaBelief(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AdaBound':
        optimizer = optim.AdaBound(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AdaMod':
        optimizer = optim.AdaMod(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Adafactor':
        optimizer = optim.Adafactor(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AdamP':
        optimizer = optim.AdamP(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'AggMo':
        optimizer = optim.AggMo(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Apollo':
        optimizer = optim.Apollo(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'DiffGrad':
        optimizer = optim.DiffGrad(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Lamb':
        optimizer = optim.Lamb(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Lookahead':
        yogi = optim.Yogi(model.parameters(), lr=cfg.optimizer.lr)
        optimizer = optim.Lookahead(yogi, k=5, alpha=0.5)
    elif name_optimizer == 'NovoGrad':
        optimizer = optim.NovoGrad(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'PID':
        optimizer = optim.PID(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'QHAdam':
        optimizer = optim.QHAdam(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'QHM':
        optimizer = optim.QHM(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'RAdam':
        optimizer = optim.RAdam(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Ranger':
        optimizer = optim.Ranger(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'RangerQH':
        optimizer = optim.RangerQH(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'RangerVA':
        optimizer = optim.RangerVA(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'SGDP':
        optimizer = optim.SGDP(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'SGDW':
        optimizer = optim.SGDW(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'SWATS':
        optimizer = optim.SWATS(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Shampoo':
        optimizer = optim.Shampoo(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Yogi':
        optimizer = optim.Yogi(model.parameters(), lr=cfg.optimizer.lr)
    elif name_optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=cfg.optimizer.lr,
                                     weight_decay=cfg.optimizer.weight_decay)
    elif name_optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=cfg.optimizer.lr,
                                    momentum=cfg.optimizer.momentum,
                                    weight_decay=cfg.optimizer.weight_decay)
    if optimizer is None:
        raise ValueError(f"Unknown optimizer: '{name_optimizer}'")
    return optimizer
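
A minimal call sketch for build_optimizer above (assumption: the config object exposes cfg.optimizer.type and cfg.optimizer.lr as attributes; a plain SimpleNamespace stands in for whatever config system is actually used):

from types import SimpleNamespace
import torch.nn as nn

cfg = SimpleNamespace(optimizer=SimpleNamespace(type='Yogi', lr=1e-3))
model = nn.Linear(16, 4)  # toy model, for illustration only
optimizer = build_optimizer(cfg, model)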
Example no. 10
def main():
    global opts, global_step
    opts = parser.parse_args()
    opts.cuda = 0

    global_step = 0

    print(opts)

    # Set GPU
    seed = opts.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_ids
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        opts.cuda = 1
        print("Currently using GPU {}".format(opts.gpu_ids))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(seed)

    else:
        print("Currently using CPU (GPU is highly recommended)")

    ######################################################################
    # Set this to True when using "MixSim_Model".
    # "MixSim_Model" moves its parameters to the GPU devices on its own for model parallelism.
    # This does not apply to "MixSim_Model_Single", which uses a single GPU; in that case set is_mixsim = False.
    ######################################################################
    is_mixsim = True
    need_pretraining = True

    # Set model
    model = MixSim_Model(NUM_CLASSES, opts.gpu_ids.split(','))
    model.eval()

    # set EMA model
    ema_model = MixSim_Model(NUM_CLASSES, opts.gpu_ids.split(','))
    for param in ema_model.parameters():
        param.detach_()
    ema_model.eval()

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    n_parameters = sum([p.data.nelement() for p in model.parameters()])
    print('  + Number of params: {}'.format(n_parameters))

    # "MixSim_Model" sends its parameters to gpu devices on its own.
    if use_gpu and (not is_mixsim):
        model.cuda()
        ema_model.cuda()

    model_for_test = ema_model  # change this to model if ema_model is not used.

    ### DO NOT MODIFY THIS BLOCK ###
    if IS_ON_NSML:
        bind_nsml(model_for_test)
        if opts.pause:
            nsml.paused(scope=locals())
    ################################

    if opts.mode == 'train':
        # set multi-gpu (We won't use this with MixSim_Model)
        if len(opts.gpu_ids.split(',')) > 1 and (not is_mixsim):
            model = nn.DataParallel(model)
            ema_model = nn.DataParallel(ema_model)
        model.train()
        ema_model.train()

        ######################################################################
        # Data Augmentation for train data and unlabeled data
        ######################################################################
        data_transforms = transforms.Compose([
            transforms.RandomResizedCrop(opts.imsize, interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply(
                [transforms.ColorJitter(0.7, 0.7, 0.7, 0.2)], p=0.5),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Set dataloader
        train_ids, val_ids, unl_ids = split_ids(
            os.path.join(DATASET_PATH, 'train/train_label'), 0.2)
        print('found {} train, {} validation and {} unlabeled images'.format(
            len(train_ids), len(val_ids), len(unl_ids)))
        train_loader = torch.utils.data.DataLoader(SimpleImageLoader(
            DATASET_PATH, 'train', train_ids, transform=data_transforms),
                                                   batch_size=opts.batchsize,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   pin_memory=True,
                                                   drop_last=True)
        print('train_loader done')

        unlabel_loader = torch.utils.data.DataLoader(
            SimpleImageLoader(DATASET_PATH,
                              'unlabel',
                              unl_ids,
                              transform=data_transforms),
            batch_size=opts.batchsize * opts.unlabelratio,
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=True)
        print('unlabel_loader done')

        validation_loader = torch.utils.data.DataLoader(
            SimpleImageLoader(DATASET_PATH,
                              'val',
                              val_ids,
                              transform=transforms.Compose([
                                  transforms.Resize(opts.imResize,
                                                    interpolation=3),
                                  transforms.CenterCrop(opts.imsize),
                                  transforms.ToTensor(),
                                  transforms.Normalize(
                                      mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225]),
                              ])),
            batch_size=opts.batchsize,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False)
        print('validation_loader done')

        if opts.steps_per_epoch < 0:
            opts.steps_per_epoch = len(train_loader)

        ######################################################################
        # Set Optimizer
        # Adamax and Yogi are optimization algorithms based on Adam with more effective learning-rate control.
        # LARS is layer-wise adaptive rate scaling.
        # LARSWrapper, an optimizer wrapper that applies LARS, helps stability with very large batch sizes.
        ######################################################################
        # optimizer = optim.Adam(model.parameters(), lr=opts.lr, weight_decay=5e-4)
        # optimizer = optim.Adamax(model.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
        # optimizer = LARSWrapper(t_optim.Yogi(model.parameters(), lr=0.01, eps=opts.optimizer_eps))
        optimizer = t_optim.Yogi(model.parameters(),
                                 lr=opts.optimizer_lr,
                                 eps=opts.optimizer_eps)
        ema_optimizer = WeightEMA(model,
                                  ema_model,
                                  lr=opts.ema_optimizer_lr,
                                  alpha=opts.ema_decay)

        # INSTANTIATE LOSS CLASS
        train_criterion_pre = NCELoss()
        train_criterion_fine = NCELoss()
        train_criterion_distill = SemiLoss()

        ######################################################################
        # INSTANTIATE STEP LEARNING SCHEDULER CLASS
        ######################################################################
        # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,  milestones=[50, 150], gamma=0.1)
        # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, eps=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=opts.steps_per_epoch * opts.epochs // 10)

        # Train and Validation
        best_acc = -1
        best_weight_acc = [-1] * 5
        is_weighted_best = [False] * 5
        for epoch in range(opts.start_epoch, opts.epochs + 1):
            # print('start training')
            if (need_pretraining and epoch <= opts.pre_train_epoch):
                pre_loss = train_pre(opts, unlabel_loader, model,
                                     train_criterion_pre, optimizer,
                                     ema_optimizer, epoch, use_gpu, scheduler,
                                     is_mixsim)
                print(
                    'epoch {:03d}/{:03d} finished, pre_loss: {:.3f}:pre-training'
                    .format(epoch, opts.epochs, pre_loss))
                continue
            elif (need_pretraining
                  and epoch <= opts.pre_train_epoch + opts.fine_tune_epoch):
                loss, avg_top1, avg_top5 = train_fine(
                    opts, train_loader, model, train_criterion_fine, optimizer,
                    ema_optimizer, epoch, use_gpu, scheduler, is_mixsim)
                print(
                    'epoch {:03d}/{:03d} finished, loss: {:.3f}, avg_top1: {:.3f}%, avg_top5: {:.3f}%: fine-tuning'
                    .format(epoch, opts.epochs, loss, avg_top1, avg_top5))
                continue
            else:
                loss, loss_x, loss_u, avg_top1, avg_top5 = train_distill(
                    opts, train_loader, unlabel_loader, model,
                    train_criterion_distill, optimizer, ema_optimizer, epoch,
                    use_gpu, scheduler, is_mixsim)
                print(
                    'epoch {:03d}/{:03d} finished, loss: {:.3f}, loss_x: {:.3f}, loss_un: {:.3f}, avg_top1: {:.3f}%, avg_top5: {:.3f}%: distillation'
                    .format(epoch, opts.epochs, loss, loss_x, loss_u, avg_top1,
                            avg_top5))

            # scheduler.step()

            ######################################################################
            # For each weight w in [0, 0.5, 1.0, 1.5, 2.0], save the model with the
            # best accuracy of (acc_top1 + w * acc_top5).
            ######################################################################
            # print('start validation')
            acc_top1, acc_top5 = validation(opts, validation_loader, ema_model,
                                            epoch, use_gpu)
            is_best = acc_top1 > best_acc
            best_acc = max(acc_top1, best_acc)
            for w in range(4):
                is_weighted_best[w] = acc_top1 + (
                    (w + 1) * 0.5 * acc_top5) > best_weight_acc[w]
                best_weight_acc[w] = max(acc_top1 + ((w + 1) * 0.5 * acc_top5),
                                         best_weight_acc[w])
            if is_best:
                print(
                    'model achieved the best accuracy ({:.3f}%) - saving best checkpoint...'
                    .format(best_acc))
                if IS_ON_NSML:
                    nsml.save(opts.name + '_best')
                else:
                    torch.save(ema_model.state_dict(),
                               os.path.join('runs', opts.name + '_best'))
            for w in range(5):
                if (is_weighted_best[w]):
                    if IS_ON_NSML:
                        nsml.save(opts.name + '_{}w_best'.format(5 * (w + 1)))
                    else:
                        torch.save(
                            ema_model.state_dict(),
                            os.path.join(
                                'runs',
                                opts.name + '_{}w_best'.format(5 * (w + 1))))
            if (epoch + 1) % opts.save_epoch == 0:
                if IS_ON_NSML:
                    nsml.save(opts.name + '_e{}'.format(epoch))
                else:
                    torch.save(
                        ema_model.state_dict(),
                        os.path.join('runs', opts.name + '_e{}'.format(epoch)))
Example no. 11
    def fit(self,
            x: Union[Sequence[Union[Sequence, np.ndarray]], np.ndarray],
            y: Union[Sequence, np.ndarray],
            n_epoch: int = 100,
            batch_size: int = 128,
            lr: float = 0.001,
            weight_decay: float = 0,
            instance_dropout: float = 0.95,
            verbose: bool = False) -> 'BaseNet':
        """
        Main fit method. fit data to model.  NOTE: his method works only with  subclasses
        Parameters
        ----------
        x: array-like
        If array: array of bags of shape Nmol*Nconf*Ndescr,
        where:  Nmol - number of molecules  in dataset, Nconf - number of conformers for a given molecule,
        Ndescr - length of descriptor string  for a conformer.
        If sequence:  sequence with bags, size of a bag  (Nconf) can vary (if varies, will be padded).
         Each entry of a bag is a descriptor
        vector for that conformer (that is not allowed to vary in length).
        y: array-like
        Labels for bags, array of shape Nmol (or sequence of length Nmol)
        n_epoch: int, default is 100
        Number of training epochs
        batch_size: int, default is 128
        Size of minibatch. TODO: implement check for minimal size
        lr: float, default 0.001
        Learning rate fo optimizer
        weight_decay: float, default is apply no L2 penalty (0)
        Value by which to multiply L2 penalty for optimizer
        instance_dropout: float, default is 0.95
        Randomly zeroes some of the instances with probability 1-instance_dropout (during training)
        using samples from a Bernoulli distribution.
        verbose: bool, default False

        Returns
        --------
        Network with trained weights        

        """

        self.instance_dropout = instance_dropout
        x, m = self.add_padding(x)
        x_train, x_val, y_train, y_val, m_train, m_val = train_val_split(
            x, y, m)
        if y_train.ndim == 1:  # convert 1d array into 2d ("column-vector")
            y_train = y_train.reshape(-1, 1)
        if y_val.ndim == 1:  # convert 1d array into 2d ("column-vector")
            y_val = y_val.reshape(-1, 1)
        if self.init_cuda:
            x_train, x_val, y_train, y_val, m_train, m_val  = x_train.cuda(), x_val.cuda(), y_train.cuda(), y_val.cuda(), \
                                                              m_train.cuda(), m_val.cuda()
        optimizer = optim.Yogi(self.parameters(),
                               lr=lr,
                               weight_decay=weight_decay)

        val_loss = []
        for epoch in range(n_epoch):
            mb = get_mini_batches(x_train,
                                  y_train,
                                  m_train,
                                  batch_size=batch_size)
            self.train()
            for x_mb, y_mb, m_mb in mb:
                loss = self.loss_batch(x_mb, y_mb, m_mb, optimizer=optimizer)

            self.eval()
            with torch.no_grad():
                loss = self.loss_batch(x_val, y_val, m_val, optimizer=None)
                val_loss.append(loss)

            min_loss_idx = val_loss.index(min(val_loss))
            if min_loss_idx == epoch:
                # state_dict() returns references to live tensors; clone so that
                # later training steps do not overwrite the saved best weights
                best_parameters = {k: v.detach().clone()
                                   for k, v in self.state_dict().items()}
                if verbose:
                    print(epoch, loss)
        self.load_state_dict(best_parameters, strict=True)
        return self