Example #1
def main(
    coach, attacker, valider, 
    trainloader, testloader, 
    start_epoch, info_path
):  
    from src.utils import save_checkpoint
    for epoch in range(start_epoch, opts.epochs):

        if epoch % SAVE_FREQ == 0:
            save_checkpoint(info_path, coach.model, coach.optimizer, coach.learning_policy, epoch)

            # train_accuracy, train_success = valider.evaluate(trainloader)
            # valid_accuracy, valid_success = valider.evaluate(testloader)
            # print(f"Train >>> [TA: {train_accuracy:.5f}]    [RA: {1 - train_success:.5f}]")
            # print(f"Test. >>> [TA: {valid_accuracy:.5f}]    [RA: {1 - valid_success:.5f}]")
            # writter.add_scalars("Accuracy", {"train":train_accuracy, "valid":valid_accuracy}, epoch)
            # writter.add_scalars("Success", {"train":train_success, "valid":valid_success}, epoch)

        running_loss = coach.adv_train(trainloader, attacker, epoch=epoch)
        writter.add_scalar("Loss", running_loss, epoch)

    train_accuracy, train_success = valider.evaluate(trainloader)
    valid_accuracy, valid_success = valider.evaluate(testloader)
    print(f"Train >>> [TA: {train_accuracy:.5f}]    [RA: {1 - train_success:.5f}]")
    print(f"Test. >>> [TA: {valid_accuracy:.5f}]    [RA: {1 - valid_success:.5f}]")
    writter.add_scalars("Accuracy", {"train":train_accuracy, "valid":valid_accuracy}, epoch)
    writter.add_scalars("Success", {"train":train_success, "valid":valid_success}, epoch)
Example #2
def manage_checkpoints(colbert, optimizer, batch_idx):
    if batch_idx % 2000 == 0:
        save_checkpoint("colbert.dnn", 0, batch_idx, colbert, optimizer)

    if batch_idx in SAVED_CHECKPOINTS:
        save_checkpoint("colbert-" + str(batch_idx) + ".dnn", 0, batch_idx,
                        colbert, optimizer)
Example #3
def manage_checkpoints(colbert, optimizer, batch_idx, output_dir):
    config = colbert.config
    checkpoint_dir = Path(output_dir)
    model_desc = f"colbert_hidden={config.hidden_size}_qlen={colbert.query_maxlen}_dlen={colbert.doc_maxlen}"
    if isinstance(colbert, SparseColBERT):
        n = "-".join([str(n) for n in colbert.n])
        k = "-".join([str(k) for k in colbert.k])
        model_desc += f"_sparse_n={n}_k={k}"
        if colbert.use_nonneg:
            model_desc += "_nonneg"
        if colbert.use_ortho:
            model_desc += "_ortho"
    else:
        model_desc += f"_dense"

    if batch_idx % 50000 == 0:
        save_checkpoint(
            checkpoint_dir / f"{model_desc}.last.dnn", 0, batch_idx, colbert, optimizer
        )

    if batch_idx in SAVED_CHECKPOINTS:
        save_checkpoint(
            checkpoint_dir / (f"{model_desc}.{batch_idx}.dnn"),
            0,
            batch_idx,
            colbert,
            optimizer,
        )
Example #4
    def save_checkpoint(self, filepath=None):
        if filepath is None:
            filename = "step_{}.ckpt".format(self.global_step)
            filepath = os.path.join(self.checkpoint_dir, filename)
        state_dict = {
            "D": self.discriminator.state_dict(),
            "G": self.generator.state_dict(),
            "d_optimizer": self.d_optimizer.state_dict(),
            "g_optimizer": self.g_optimizer.state_dict(),
            "transition_step": self.transition_step,
            "is_transitioning": self.is_transitioning,
            "global_step": self.global_step,
            "total_time": self.total_time,
            "running_average_generator": self.running_average_generator.state_dict(),
            "latest_switch": self.latest_switch,
            "current_imsize": self.current_imsize,
            "num_skipped_steps": self.num_skipped_steps,
        }
        save_checkpoint(state_dict, filepath, max_keep=2)
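
Example #4 forwards its state dict to an outer save_checkpoint(state_dict, filepath, max_keep=2) that is not shown. A rough sketch of one way such a helper could rotate old files, assuming max_keep means "retain only the N newest checkpoints" and that the history is tracked in process memory:

import os
from collections import deque

import torch

_recent_ckpts = deque()  # in-memory history of written checkpoint files (assumption)


def save_checkpoint(state_dict, filepath, max_keep=2):
    """Write the checkpoint, then delete all but the `max_keep` newest files."""
    torch.save(state_dict, filepath)
    _recent_ckpts.append(filepath)
    while len(_recent_ckpts) > max_keep:
        old = _recent_ckpts.popleft()
        if os.path.exists(old):
            os.remove(old)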
Example #5
    def test_save_and_load_checkpoint(self):
        model = torchvision.models.resnet18(pretrained=False)
        utils.save_checkpoint(model,
                              epoch=100,
                              filename='tmp.pth',
                              save_arch=True)

        loaded_model = utils.load_model('tmp.pth')

        torch.testing.assert_allclose(model.conv1.weight,
                                      loaded_model.conv1.weight)

        model.conv1.weight = nn.Parameter(torch.zeros_like(model.conv1.weight))
        model = utils.load_checkpoint('tmp.pth', model=model)['model']

        assert (model.conv1.weight != 0).any()
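
Example #5 exercises a utils API whose implementation is not shown. Below is a minimal sketch of save/load helpers that would satisfy this round-trip test; the checkpoint keys ("epoch", "state_dict", "arch") are assumptions, not the tested library's documented format:

import torch


def save_checkpoint(model, epoch, filename, save_arch=False):
    """Save weights and, when save_arch=True, pickle the whole module too."""
    state = {"epoch": epoch, "state_dict": model.state_dict()}
    if save_arch:
        state["arch"] = model  # lets load_model rebuild without knowing the class
    torch.save(state, filename)


def load_model(filename):
    """Rebuild a model purely from a checkpoint written with save_arch=True."""
    # On PyTorch >= 2.6 pass weights_only=False to allow unpickling the module.
    state = torch.load(filename)
    model = state["arch"]
    model.load_state_dict(state["state_dict"])
    return model


def load_checkpoint(filename, model=None):
    """Load weights into an existing model and return the checkpoint dict."""
    state = torch.load(filename)
    if model is not None:
        model.load_state_dict(state["state_dict"])
        state["model"] = model
    return state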
Example #6
File: AT.py Project: MTandHJ/roboc
def main(coach, attacker, valider, trainloader, testloader, start_epoch,
         info_path, log_path):
    from src.utils import save_checkpoint, TrackMeter, ImageMeter
    from src.dict2obj import Config
    acc_logger = Config(train=TrackMeter("Train"), valid=TrackMeter("Valid"))
    acc_logger.plotter = ImageMeter(*acc_logger.values(), title="Accuracy")

    rob_logger = Config(train=TrackMeter("Train"), valid=TrackMeter("Valid"))
    rob_logger.plotter = ImageMeter(*rob_logger.values(), title="Robustness")

    for epoch in range(start_epoch, opts.epochs):

        if epoch % SAVE_FREQ == 0:
            save_checkpoint(info_path, coach.model, coach.optimizer,
                            coach.learning_policy, epoch)

        if epoch % PRINT_FREQ == 0:
            evaluate(valider=valider,
                     trainloader=trainloader,
                     testloader=testloader,
                     acc_logger=acc_logger,
                     rob_logger=rob_logger,
                     writter=writter,
                     epoch=epoch)

        running_loss = coach.adv_train(trainloader,
                                       attacker,
                                       leverage=opts.leverage,
                                       epoch=epoch)
        writter.add_scalar("Loss", running_loss, epoch)

    evaluate(valider=valider,
             trainloader=trainloader,
             testloader=testloader,
             acc_logger=acc_logger,
             rob_logger=rob_logger,
             writter=writter,
             epoch=opts.epochs)

    acc_logger.plotter.plot()
    rob_logger.plotter.plot()
    acc_logger.plotter.save(writter)
    rob_logger.plotter.save(writter)
Example #7
    def run(self, dataloader, epochs=1):
        print(">> Running trainer")
        for epoch in range(epochs):
            print(">>> Epoch %s" % epoch)
            for idx, (image,
                      target) in enumerate(tqdm.tqdm(dataloader, ascii=True)):
                image, target = image.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                predict = self.net(image)
                loss = self.loss(predict, target.squeeze(1))
                loss.backward()
                self.losses.append(loss.item())
                self.optimizer.step()
                # if idx % 10 == 0:
                # print(">>> Loss: {}".format(np.mean(self.losses[-10:])))

                if self.config['DEBUG'] == True:
                    break
            print("Trainer epoch finished")
            save_checkpoint(self.net, {"epoch": epoch},
                            "{}-net.pth".format(epoch))
Example #8
    def save_model(self,
                   epoch: int,
                   loss_val: float,
                   offset: int,
                   is_best: bool,
                   filename: str = None,
                   **kwargs) -> None:

        checkpoint_progress = tqdm(ncols=100,
                                   desc='Saving Checkpoint',
                                   position=offset)
        param = {
            'arch': self.args.model,
            'opt': self.args.optimizer,
            'model_state_dict': self.model_and_loss.module.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch,
            'best_EPE': loss_val,
            'exp_key': self.experiment.get_key()
        }

        if self.lr_scheduler is not None:
            sch = {
                'scheduler': self.args.lr_scheduler,
                'lr_state_dict': self.lr_scheduler.state_dict()
            }
            param.update(sch)

        param.update(kwargs)  # for extra input arguments
        utils.save_checkpoint(param,
                              is_best,
                              self.args.save,
                              self.args.model,
                              filename=filename)
        checkpoint_progress.update(1)
        checkpoint_progress.close()
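
Example #8 passes an is_best flag into utils.save_checkpoint. The common pattern behind such a flag is to always write the latest checkpoint and additionally copy it to a "best" file when the metric improved; a hedged sketch with assumed file names:

import os
import shutil

import torch


def save_checkpoint(state, is_best, save_dir, model_name, filename=None):
    """Write the latest checkpoint; copy it to a *_best file when is_best is True."""
    os.makedirs(save_dir, exist_ok=True)
    if filename is None:
        filename = "{}_checkpoint.pth.tar".format(model_name)
    path = os.path.join(save_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(save_dir, "{}_best.pth.tar".format(model_name)))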
Example #9
def train(args):
    """
    :param args:
    :return:
    """
    grammar = semQL.Grammar()
    sql_data, table_data, val_sql_data, val_table_data = utils.load_dataset(
        args.dataset, use_small=args.toy)

    model = IRNet(args, grammar)
    if args.cuda: model.cuda()

    # now get the optimizer
    optimizer_cls = eval('torch.optim.%s' % args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    print('Enable Learning Rate Scheduler: ', args.lr_scheduler)
    if args.lr_scheduler:
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[21, 41], gamma=args.lr_scheduler_gammar)
    else:
        scheduler = None

    print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
    print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)

    if args.load_model:
        print('load pretrained model from %s' % (args.load_model))
        pretrained_model = torch.load(
            args.load_model, map_location=lambda storage, loc: storage)
        pretrained_modeled = copy.deepcopy(pretrained_model)
        for k in pretrained_model.keys():
            if k not in model.state_dict().keys():
                del pretrained_modeled[k]

        model.load_state_dict(pretrained_modeled)

    model.word_emb = utils.load_word_emb(args.glove_embed_path)
    # begin train

    model_save_path = utils.init_log_checkpoint_path(args)
    utils.save_args(args, os.path.join(model_save_path, 'config.json'))
    best_dev_acc = .0

    try:
        with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd:
            for epoch in tqdm.tqdm(range(args.epoch)):
                if args.lr_scheduler:
                    scheduler.step()
                epoch_begin = time.time()
                loss = utils.epoch_train(
                    model,
                    optimizer,
                    args.batch_size,
                    sql_data,
                    table_data,
                    args,
                    loss_epoch_threshold=args.loss_epoch_threshold,
                    sketch_loss_coefficient=args.sketch_loss_coefficient)
                epoch_end = time.time()
                json_datas, sketch_acc, acc, counts, corrects = utils.epoch_acc(
                    model,
                    args.batch_size,
                    val_sql_data,
                    val_table_data,
                    beam_size=args.beam_size)
                # acc = utils.eval_acc(json_datas, val_sql_data)

                if acc > best_dev_acc:
                    utils.save_checkpoint(
                        model, os.path.join(model_save_path,
                                            'best_model.model'))
                    best_dev_acc = acc
                utils.save_checkpoint(
                    model,
                    os.path.join(model_save_path, '{%s}_{%s}.model') %
                    (epoch, acc))

                log_str = 'Epoch: %d, Loss: %f, Sketch Acc: %f, Acc: %f, time: %f\n' % (
                    epoch + 1, loss, sketch_acc, acc, epoch_end - epoch_begin)
                tqdm.tqdm.write(log_str)
                epoch_fd.write(log_str)
                epoch_fd.flush()
    except Exception as e:
        # Save model
        utils.save_checkpoint(model,
                              os.path.join(model_save_path, 'end_model.model'))
        print(e)
        tb = traceback.format_exc()
        print(tb)
    else:
        utils.save_checkpoint(model,
                              os.path.join(model_save_path, 'end_model.model'))
        json_datas, sketch_acc, acc, counts, corrects = utils.epoch_acc(
            model,
            args.batch_size,
            val_sql_data,
            val_table_data,
            beam_size=args.beam_size)
        # acc = utils.eval_acc(json_datas, val_sql_data)

        print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (
            sketch_acc,
            acc,
            acc,
        ))
Example #10
def job(tuning, params_path, devices, resume, save_interval):
    global params
    if tuning:
        with open(params_path, 'r') as f:
            params = json.load(f)
        mode_str = 'tuning'
        setting = '_'.join(f'{tp}-{params[tp]}'
                           for tp in params['tuning_params'])
    else:
        mode_str = 'train'
        setting = ''

    exp_path = ROOT + f'experiments/{params["ex_name"]}/'
    os.environ['CUDA_VISIBLE_DEVICES'] = devices

    if resume is None:
        # Kept consistent between the C-AIR and ABCI environments.
        params[
            'base_ckpt_path'] = f'experiments/v1only/ep4_augmentation-soft_epochs-5_loss-{params["loss"]}.pth'
        params[
            'clean_path'] = ROOT + f'input/clean/train19_cleaned_verifythresh{params["verifythresh"]}_freqthresh{params["freqthresh"]}.csv'
    else:
        params = utils.load_checkpoint(path=resume, params=True)['params']

    logger, writer = utils.get_logger(
        log_dir=exp_path + f'{mode_str}/log/{setting}',
        tensorboard_dir=exp_path + f'{mode_str}/tf_board/{setting}')

    if params['augmentation'] == 'soft':
        params['scale_limit'] = 0.2
        params['brightness_limit'] = 0.1
    elif params['augmentation'] == 'middle':
        params['scale_limit'] = 0.3
        params['shear_limit'] = 4
        params['brightness_limit'] = 0.1
        params['contrast_limit'] = 0.1
    else:
        raise ValueError

    train_transform, eval_transform = data_utils.build_transforms(
        scale_limit=params['scale_limit'],
        shear_limit=params['shear_limit'],
        brightness_limit=params['brightness_limit'],
        contrast_limit=params['contrast_limit'],
    )

    data_loaders = data_utils.make_train_loaders(
        params=params,
        data_root=ROOT + 'input/' + params['data'],
        train_transform=train_transform,
        eval_transform=eval_transform,
        scale='SS2',
        test_size=0,
        class_topk=params['class_topk'],
        num_workers=8)

    model = models.LandmarkNet(
        n_classes=params['class_topk'],
        model_name=params['model_name'],
        pooling=params['pooling'],
        loss_module=params['loss'],
        s=params['s'],
        margin=params['margin'],
        theta_zero=params['theta_zero'],
        use_fc=params['use_fc'],
        fc_dim=params['fc_dim'],
    ).cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer = utils.get_optim(params, model)

    if resume is None:
        sdict = torch.load(ROOT + params['base_ckpt_path'])['state_dict']
        if params['loss'] == 'adacos':
            del sdict['final.W']  # remove fully-connected layer
        elif params['loss'] == 'softmax':
            del sdict['final.weight'], sdict[
                'final.bias']  # remove fully-connected layer
        else:
            del sdict['final.weight']  # remove fully-connected layer
        model.load_state_dict(sdict, strict=False)

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=params['epochs'] * len(data_loaders['train']),
            eta_min=3e-6)
        start_epoch, end_epoch = (0,
                                  params['epochs'] - params['scaleup_epochs'])
    else:
        ckpt = utils.load_checkpoint(path=resume,
                                     model=model,
                                     optimizer=optimizer,
                                     epoch=True)
        model, optimizer, start_epoch = ckpt['model'], ckpt[
            'optimizer'], ckpt['epoch'] + 1
        end_epoch = params['epochs']

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=params['epochs'] * len(data_loaders['train']),
            eta_min=3e-6,
            last_epoch=start_epoch * len(data_loaders['train']))

        setting += 'scaleup_' + resume.split('/')[-1].replace('.pth', '')

        data_loaders = data_utils.make_verified_train_loaders(
            params=params,
            data_root=ROOT + 'input/' + params['data'],
            train_transform=train_transform,
            eval_transform=eval_transform,
            scale='M2',
            test_size=0,
            num_workers=8)
        batch_norm.freeze_bn(model)

    if len(devices.split(',')) > 1:
        model = nn.DataParallel(model)

    for epoch in range(start_epoch, end_epoch):
        logger.info(f'Epoch {epoch}/{end_epoch}')

        # ============================== train ============================== #
        model.train(True)

        losses = utils.AverageMeter()
        prec1 = utils.AverageMeter()

        for i, (_, x, y) in tqdm(enumerate(data_loaders['train']),
                                 total=len(data_loaders['train']),
                                 miniters=None,
                                 ncols=55):
            x = x.to('cuda')
            y = y.to('cuda')

            outputs = model(x, y)
            loss = criterion(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            acc = metrics.accuracy(outputs, y)
            losses.update(loss.item(), x.size(0))
            prec1.update(acc, x.size(0))

            if i % 100 == 99:
                logger.info(
                    f'{epoch+i/len(data_loaders["train"]):.2f}epoch | {setting} acc: {prec1.avg}'
                )

        train_loss = losses.avg
        train_acc = prec1.avg

        writer.add_scalars('Loss', {'train': train_loss}, epoch)
        writer.add_scalars('Acc', {'train': train_acc}, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

        if (epoch + 1) == end_epoch or (epoch + 1) % save_interval == 0:
            output_file_name = exp_path + f'ep{epoch}_' + setting + '.pth'
            utils.save_checkpoint(path=output_file_name,
                                  model=model,
                                  epoch=epoch,
                                  optimizer=optimizer,
                                  params=params)

    model = model.module
    datasets = ('oxford5k', 'paris6k', 'roxford5k', 'rparis6k')
    results = eval_datasets(model,
                            datasets=datasets,
                            ms=True,
                            tta_gem_p=1.0,
                            logger=logger)

    if tuning:
        tuning_result = {}
        for d in datasets:
            if d in ('oxford5k', 'paris6k'):
                tuning_result[d] = results[d]
            else:
                for key in ['mapE', 'mapM', 'mapH']:
                    mapE, mapM, mapH, mpE, mpM, mpH, kappas = results[d]
                    tuning_result[d + '-' + key] = [eval(key)]
        utils.write_tuning_result(params, tuning_result,
                                  exp_path + 'tuning/results.csv')
Example #11
File: GAT.py Project: yangji9181/ALA
def train(args, model, data, log_dir, logger, optimizer=None):
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    t = time.time()
    best_acc, best_epoch = 0, 0
    count = 0
    model.train()
    features = data.features
    if args.sparse:
        adjs = data.adj
    else:
        adjs = torch.from_numpy(data.old_adj.todense())

    for epoch in range(1, args.epochs+1):
        losses = []
        optimizer.zero_grad()

        # If data is too big, use sampled data to train: sampled_features, sampled_adjs
        (sampled_features, sampled_adjs), (sampled_link_featuresL, sampled_link_featuresR), sampled_labels = data.sample('link', adjs)
        loss = model(sampled_features, sampled_adjs, sampled_link_featuresL, sampled_link_featuresR, sampled_labels)

        # If data is not too big, use whole data to train: features, adjs
        # loss = model(features, adjs, sampled_link_featuresL, sampled_link_featuresR, sampled_labels)

        loss.backward()
        optimizer.step()

        if epoch % args.log_every == 0:
            losses.append(loss.item())

        if epoch % args.log_every == 0:
            duration = time.time() - t
            msg = 'Epoch: {:04d} '.format(epoch)
            msg += 'loss: {:.4f}\t'.format(loss)
            logger.info(msg+' time: {:d}s '.format(int(duration)))

        if epoch % args.eval_every == 0:
            learned_embed = gensim.models.keyedvectors.Word2VecKeyedVectors(model.nembed)

            # If data is not too big, use whole data to get embeddings
            embedding = model.generate_embedding(features, adjs)
            learned_embed.add([str(i) for i in range(embedding.shape[0])], embedding)

            # If data is too big, Sample data to get embedding
            # for i in range(0, len(args.nodes), args.sample_embed):
            #     nodes = args.nodes[i:i+args.sample_embed]
            #     test_features = features[nodes]
            #     test_adjs = torch.zeros((len(nodes), len(nodes)))
            #     for i, n in enumerate(nodes):
            #         test_adjs[i] = adjs[n][nodes]
            #     embedding = model.generate_embedding(test_features, test_adjs)
            #     learned_embed.add([str(node) for node in nodes], embedding)

            # If data is too big, use only test data to get embedding
            # test_features = features[args.nodes]
            # test_adjs = torch.zeros((len(args.nodes), len(args.nodes)))
            # for i, n in enumerate(args.nodes):
            #     test_adjs[i] = adjs[n][args.nodes]
            # embedding = model.generate_embedding(test_features, test_adjs)
            # learned_embed.add([str(i) for i in args.nodes], embedding)


            train_acc, test_acc, std = evaluate(args, learned_embed, logger)
            duration = time.time() - t
            logger.info('Epoch: {:04d} '.format(epoch)+
                        'train_acc: {:.2f} '.format(train_acc)+
                        'test_acc: {:.2f} '.format(test_acc)+
                        'std: {:.2f} '.format(std)+
                        'time: {:d}s'.format(int(duration)))
            if test_acc > best_acc:
                best_acc = test_acc
                best_epoch = epoch
                save_checkpoint({
                    'args': args,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, log_dir,
                    f'epoch{epoch}_time{int(duration):d}_trainacc{train_acc:.2f}_testacc{test_acc:.2f}_std{std:.2f}.pth.tar', logger, True)
                count = 0
            else:
                if args.early_stop:
                    count += args.eval_every
                if count >= args.patience:
                    logger.info('early stopped!')
                    break

    logger.info(f'best test acc={best_acc:.2f} @ epoch:{int(best_epoch):d}')
    if args.save_emb:
        learned_embed = gensim.models.keyedvectors.Word2VecKeyedVectors(model.nembed)
        # If data is not too big, use whole data to get embeddings
        embedding = model.generate_embedding(features, adjs)
        learned_embed.add([str(i) for i in range(embedding.shape[0])], embedding)

        # If data is too big, Sample data to get embedding
        # for i in range(0, len(args.nodes), args.sample_embed):
        #     nodes = args.nodes[i:i+args.sample_embed]
        #     test_features = features[nodes]
        #     test_adjs = torch.zeros((len(nodes), len(nodes)))
        #     for i, n in enumerate(nodes):
        #         test_adjs[i] = adjs[n][nodes]
        #     embedding = model.generate_embedding(test_features, test_adjs)
        #     learned_embed.add([str(node) for node in nodes], embedding)

        # If data is too big, use only test data to get embedding
        # test_features = features[args.nodes]
        # test_adjs = torch.zeros((len(args.nodes), len(args.nodes)))
        # for i, n in enumerate(args.nodes):
        #     test_adjs[i] = adjs[n][args.nodes]
        # embedding = model.generate_embedding(test_features, test_adjs)
        # learned_embed.add([str(i) for i in args.nodes], embedding)
        save_embedding(learned_embed, args.save_emb_file, binary=(os.path.splitext(args.save_emb_file)[1]== 'bin'))

    return best_acc
Example #12
    def _save_checkpoint(self, current_bleu4, is_best):
        save_checkpoint(self.data_name, self.current_epoch,
                        self.epochs_since_improvement, self.encoder,
                        self.decoder, self.encoder_optimizer,
                        self.decoder_optimizer, self.save_dir, current_bleu4,
                        is_best)
Example #13
def job(tuning, params_path, devices, resume, save_interval):

    global params
    if tuning:
        with open(params_path, 'r') as f:
            params = json.load(f)
        mode_str = 'tuning'
        setting = '_'.join(f'{tp}-{params[tp]}'
                           for tp in params['tuning_params'])
    else:
        mode_str = 'train'
        setting = ''

    exp_path = ROOT + f'experiments/{params["ex_name"]}/'
    os.environ['CUDA_VISIBLE_DEVICES'] = devices

    logger, writer = utils.get_logger(
        log_dir=exp_path + f'{mode_str}/log/{setting}',
        tensorboard_dir=exp_path + f'{mode_str}/tf_board/{setting}')
    train_transform, eval_transform = build_transforms(
        scale_range=params['scale_range'],
        brightness_range=params['brightness_range'])
    data_loaders = data_utils.make_train_loaders(
        params=params,
        data_root=ROOT + 'input/train2018',
        train_transform=train_transform,
        eval_transform=eval_transform,
        class_topk=params['class_topk'],
        num_workers=8)

    model = models.LandmarkFishNet(
        n_classes=params['class_topk'],
        model_name=params['model_name'],
        pooling_strings=params['pooling'].split(','),
        loss_module='arcface',
        s=30.0,
        margin=params['margin'],
        use_fc=params['use_fc'],
        fc_dim=params['fc_dim'],
    ).cuda()
    optimizer = utils.get_optim(params, model)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=params['epochs'] * len(data_loaders['train']),
        eta_min=1e-6)

    if len(devices.split(',')) > 1:
        model = nn.DataParallel(model)
    if resume is not None:
        model, optimizer = utils.load_checkpoint(path=resume,
                                                 model=model,
                                                 optimizer=optimizer)

    for epoch in range(params['epochs']):
        logger.info(
            f'Epoch {epoch}/{params["epochs"]} | lr: {optimizer.param_groups[0]["lr"]}'
        )

        # ============================== train ============================== #
        model.train(True)

        losses = utils.AverageMeter()
        prec1 = utils.AverageMeter()

        for i, (_, x, y) in tqdm(enumerate(data_loaders['train']),
                                 total=len(data_loaders['train']),
                                 miniters=None,
                                 ncols=55):
            x = x.to('cuda')
            y = y.to('cuda')

            outputs = model(x, y)
            loss = criterion(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            acc = metrics.accuracy(outputs, y)
            losses.update(loss.item(), x.size(0))
            prec1.update(acc, x.size(0))

            if i % 100 == 99:
                logger.info(
                    f'{epoch+i/len(data_loaders["train"]):.2f}epoch | {setting} acc: {prec1.avg}'
                )

        train_loss = losses.avg
        train_acc = prec1.avg

        # ============================== validation ============================== #
        model.train(False)
        losses.reset()
        prec1.reset()

        for i, (_, x, y) in tqdm(enumerate(data_loaders['val']),
                                 total=len(data_loaders['val']),
                                 miniters=None,
                                 ncols=55):
            x = x.to('cuda')
            y = y.to('cuda')

            with torch.no_grad():
                outputs = model(x, y)
                loss = criterion(outputs, y)

            acc = metrics.accuracy(outputs, y)
            losses.update(loss.item(), x.size(0))
            prec1.update(acc, x.size(0))

        val_loss = losses.avg
        val_acc = prec1.avg

        logger.info(f'[Val] Loss: \033[1m{val_loss:.4f}\033[0m | '
                    f'Acc: \033[1m{val_acc:.4f}\033[0m\n')

        writer.add_scalars('Loss', {'train': train_loss}, epoch)
        writer.add_scalars('Acc', {'train': train_acc}, epoch)
        writer.add_scalars('Loss', {'val': val_loss}, epoch)
        writer.add_scalars('Acc', {'val': val_acc}, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

        if save_interval > 0:
            if (epoch +
                    1) == params['epochs'] or (epoch + 1) % save_interval == 0:
                output_file_name = exp_path + f'ep{epoch}_' + setting + '.pth'
                utils.save_checkpoint(path=output_file_name,
                                      model=model,
                                      epoch=epoch,
                                      optimizer=optimizer,
                                      params=params)

    if tuning:
        tuning_result = {}
        for key in ['train_loss', 'train_acc', 'val_loss', 'val_acc']:
            tuning_result[key] = [eval(key)]
        utils.write_tuning_result(params, tuning_result,
                                  exp_path + 'tuning/results.csv')
Example #14
        train_loss = losses.avg
        train_acc = prec1.avg

        print("[{:5d}] => loss={:.9f}, acc={:.9f}; lr={:.9f}.".format(epoch, train_loss, train_acc))


        writer.add_scalars('Loss', {'train': train_loss}, epoch)
        writer.add_scalars('Acc', {'train': train_acc}, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

        if (epoch + 1) == end_epoch or (epoch + 1) % save_interval == 0:
            output_file_name = exp_path + f'ep{epoch}_' + setting + '.pth'
            utils.save_checkpoint(path=output_file_name,
                                  model=model,
                                  epoch=epoch,
                                  optimizer=optimizer,
                                  params=params)


        # ============================== validation ============================== #
        model.eval()
        val_losses = utils.AverageMeter()
        val_prec1 = utils.AverageMeter()

        with torch.no_grad():
            for i, (_, x, y) in tqdm(enumerate(data_loaders['val']),
                                 total=len(data_loaders['val']),
                                 miniters=None, ncols=55):

                if num_GPU>0:
Example #15
def job(tuning, params_path, devices, resume, save_interval):
    global params
    if tuning:
        with open(params_path, 'r') as f:
            params = json.load(f)
        mode_str = 'tuning'
        setting = '_'.join(f'{tp}-{params[tp]}'
                           for tp in params['tuning_params'])
    else:
        mode_str = 'train'
        setting = ''

    # Change the seed whenever the parameters change (hoping for a seed-averaging effect)
    seed = sum(ord(_) for _ in str(params.values()))
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False

    exp_path = ROOT + f'experiments/{params["ex_name"]}/'
    os.environ['CUDA_VISIBLE_DEVICES'] = devices

    logger, writer = utils.get_logger(
        log_dir=exp_path + f'{mode_str}/log/{setting}',
        tensorboard_dir=exp_path + f'{mode_str}/tf_board/{setting}')

    if params['augmentation'] == 'soft':
        params['scale_limit'] = 0.2
        params['brightness_limit'] = 0.1
    elif params['augmentation'] == 'middle':
        params['scale_limit'] = 0.3
        params['shear_limit'] = 4
        params['brightness_limit'] = 0.1
        params['contrast_limit'] = 0.1
    else:
        raise ValueError

    train_transform, eval_transform = data_utils.build_transforms(
        scale_limit=params['scale_limit'],
        shear_limit=params['shear_limit'],
        brightness_limit=params['brightness_limit'],
        contrast_limit=params['contrast_limit'],
    )

    data_loaders = data_utils.make_train_loaders(
        params=params,
        data_root=ROOT + 'input/' + params['data'],
        train_transform=train_transform,
        eval_transform=eval_transform,
        scale='S',
        test_size=0,
        class_topk=params['class_topk'],
        num_workers=8)

    model = models.LandmarkNet(
        n_classes=params['class_topk'],
        model_name=params['model_name'],
        pooling=params['pooling'],
        loss_module=params['loss'],
        s=params['s'],
        margin=params['margin'],
        theta_zero=params['theta_zero'],
        use_fc=params['use_fc'],
        fc_dim=params['fc_dim'],
    ).cuda()
    optimizer = utils.get_optim(params, model)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=params['epochs'] * len(data_loaders['train']),
        eta_min=3e-6)
    start_epoch = 0

    if len(devices.split(',')) > 1:
        model = nn.DataParallel(model)

    for epoch in range(start_epoch, params['epochs']):

        logger.info(
            f'Epoch {epoch}/{params["epochs"]} | lr: {optimizer.param_groups[0]["lr"]}'
        )

        # ============================== train ============================== #
        model.train(True)

        losses = utils.AverageMeter()
        prec1 = utils.AverageMeter()

        for i, (_, x, y) in tqdm(enumerate(data_loaders['train']),
                                 total=len(data_loaders['train']),
                                 miniters=None,
                                 ncols=55):
            x = x.to('cuda')
            y = y.to('cuda')

            outputs = model(x, y)
            loss = criterion(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            acc = metrics.accuracy(outputs, y)
            losses.update(loss.item(), x.size(0))
            prec1.update(acc, x.size(0))

            if i % 100 == 99:
                logger.info(
                    f'{epoch+i/len(data_loaders["train"]):.2f}epoch | {setting} acc: {prec1.avg}'
                )

        train_loss = losses.avg
        train_acc = prec1.avg

        writer.add_scalars('Loss', {'train': train_loss}, epoch)
        writer.add_scalars('Acc', {'train': train_acc}, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

        if (epoch + 1) == params['epochs'] or (epoch + 1) % save_interval == 0:
            output_file_name = exp_path + f'ep{epoch}_' + setting + '.pth'
            utils.save_checkpoint(path=output_file_name,
                                  model=model,
                                  epoch=epoch,
                                  optimizer=optimizer,
                                  params=params)

    model = model.module
    datasets = ('roxford5k', 'rparis6k')
    results = eval_datasets(model,
                            datasets=datasets,
                            ms=False,
                            tta_gem_p=1.0,
                            logger=logger)

    if tuning:
        tuning_result = {}
        for d in datasets:
            for key in ['mapE', 'mapM', 'mapH']:
                mapE, mapM, mapH, mpE, mpM, mpH, kappas = results[d]
                tuning_result[d + '-' + key] = [eval(key)]
        utils.write_tuning_result(params, tuning_result,
                                  exp_path + 'tuning/results.csv')
Example #16
def main(seed, pretrain, resume, evaluate, print_runtime, epochs, disable_tqdm,
         visdom_port, ckpt_path, make_plot, cuda):
    device = torch.device("cuda" if cuda else "cpu")
    callback = None if visdom_port is None else VisdomLogger(port=visdom_port)
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
    torch.cuda.set_device(0)
    # create model
    print("=> Creating model '{}'".format(
        ex.current_run.config['model']['arch']))
    model = torch.nn.DataParallel(get_model()).cuda()

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    optimizer = get_optimizer(model)

    if pretrain:
        pretrain = os.path.join(pretrain, 'checkpoint.pth.tar')
        if os.path.isfile(pretrain):
            print("=> loading pretrained weight '{}'".format(pretrain))
            checkpoint = torch.load(pretrain)
            model_dict = model.state_dict()
            params = checkpoint['state_dict']
            params = {k: v for k, v in params.items() if k in model_dict}
            model_dict.update(params)
            model.load_state_dict(model_dict)
        else:
            print(
                '[Warning]: Did not find pretrained model {}'.format(pretrain))

    if resume:
        resume_path = ckpt_path + '/checkpoint.pth.tar'
        if os.path.isfile(resume_path):
            print("=> loading checkpoint '{}'".format(resume_path))
            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # scheduler.load_state_dict(checkpoint['scheduler'])
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_path, checkpoint['epoch']))
        else:
            print('[Warning]: Did not find checkpoint {}'.format(resume_path))
    else:
        start_epoch = 0
        best_prec1 = -1

    cudnn.benchmark = True

    # Data loading code
    evaluator = Evaluator(device=device, ex=ex)
    if evaluate:
        print("Evaluating")
        results = evaluator.run_full_evaluation(model=model,
                                                model_path=ckpt_path,
                                                callback=callback)
        #MYMOD
        #,model_tag='best',
        #shots=[5],
        #method="tim-gd")
        return results

    # If this line is reached, then training the model
    trainer = Trainer(device=device, ex=ex)
    scheduler = get_scheduler(optimizer=optimizer,
                              num_batches=len(trainer.train_loader),
                              epochs=epochs)
    tqdm_loop = warp_tqdm(list(range(start_epoch, epochs)),
                          disable_tqdm=disable_tqdm)
    for epoch in tqdm_loop:
        # Do one epoch
        trainer.do_epoch(model=model,
                         optimizer=optimizer,
                         epoch=epoch,
                         scheduler=scheduler,
                         disable_tqdm=disable_tqdm,
                         callback=callback)

        # Evaluation on validation set
        prec1 = trainer.meta_val(model=model,
                                 disable_tqdm=disable_tqdm,
                                 epoch=epoch,
                                 callback=callback)
        print('Meta Val {}: {}'.format(epoch, prec1))
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if not disable_tqdm:
            tqdm_loop.set_description('Best Acc {:.2f}'.format(best_prec1 *
                                                               100.))

        # Save checkpoint
        save_checkpoint(state={
            'epoch': epoch + 1,
            'arch': ex.current_run.config['model']['arch'],
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict()
        },
                        is_best=is_best,
                        folder=ckpt_path)
        if scheduler is not None:
            scheduler.step()

    # Final evaluation on test set
    results = evaluator.run_full_evaluation(model=model, model_path=ckpt_path)
    return results
Example #17
        'model': model,
        'dataset': train_dataset,
        'optimizer': optimizer,
        'writer': writer
    })

    with torch.no_grad():
        val_loss, acc, report = eval({
            'model': model,
            'dataset': val_dataset,
            'optimizer': optimizer,
        })
    accuracy.append(acc)

    save_checkpoint(model,
                    extra={
                        'lb': acc,
                        'epoch': state['epoch'] + e
                    },
                    checkpoint='last_clf.pth')

    writer.add_scalars('classifier/losses', {
        'train_loss': train_loss,
        'val_loss': val_loss
    })
    writer.add_scalar('classifier/metric', acc)

    print("Epoch: %d, Train: %.3f, Val: %.3f, Acc: %.3f" %
          (e, train_loss, val_loss, acc))
    print(report)
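
Example #17 (and Example #19 below) calls save_checkpoint(model, optimizer=..., extra=..., checkpoint=...), attaching free-form metadata to the saved file. A short sketch of a helper with that keyword style; the stored key names are assumptions:

import torch


def save_checkpoint(model, optimizer=None, extra=None, checkpoint="checkpoint.pth"):
    """Save model weights plus an optional optimizer and a free-form metadata dict."""
    state = {"model_state_dict": model.state_dict()}
    if optimizer is not None:
        state["optimizer_state_dict"] = optimizer.state_dict()
    if extra is not None:
        state["extra"] = dict(extra)  # e.g. {"epoch": ..., "lb_acc": ...}
    torch.save(state, checkpoint)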
Example #18
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    #     if args.load_huggingface:
    #         args.make_vocab_size_divisible_by = 1

    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT3 model')
        print_args(args)

    # Random seeds for reproducability.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.vocab_size, args.eod_token, tokenizer = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    # Resume data loader if necessary.
    if args.resume_dataloader:
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % len(train_data)
            print_rank_0(f"Resume train set from iteration {train_data.batch_sampler.start_iter}")
        if val_data is not None:
            start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval
            val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None

    iteration = 0
    if args.train_iters > 0:
        if args.do_train:
            iteration, skipped = train(model, optimizer,
                                       lr_scheduler,
                                       train_data_iterator,
                                       val_data,
                                       timers,
                                       args,
                                       tokenizer)

        if args.do_valid:
            prefix = 'the end of training for val data'
            # val_loss, val_ppl
            _ = evaluate_and_print_results(prefix, iter(val_data) if val_data else None,
                                           model, args, timers, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args, deepspeed=DEEPSPEED_WRAP and args.deepspeed)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, iter(test_data) if test_data else None,
                                   model, args, timers, True)
Example #19
def main():
    TRAIN_PATH = os.path.join(DIRECTORY, 'train')
    if N_FOLDS == 1:
        folds = split_train_val()
    else:
        folds = kfold_split()

    board = SummaryWriter()
    loaded_models = [
        exp_load_model("fold-{}.pth".format(f))
        for f in range(min(N_FOLDS, TAKE_FOLDS))
    ]
    accuracy = [loaded_model[3]['lb_acc'] for loaded_model in loaded_models]
    for e in range(TRAINING_EPOCH):
        accuracy = []
        for i, (train_list, val_list) in enumerate(folds[:TAKE_FOLDS]):
            model, optimizer, scheduler, state = loaded_models[i]
            start_epoch, best_lb_acc, rlr_iter = state['epoch'], state[
                'lb_acc'], state.get('iter', 0)

            dataset = TGSSaltDatasetAug(TRAIN_PATH, train_list, aug=True)
            dataset_val = TGSSaltDatasetAug(TRAIN_PATH, val_list)

            train_loss = train_iter(dataset, model, optimizer)
            with torch.no_grad():
                val_loss, lb_acc, _ = eval(dataset_val, model)
            scheduler.step(lb_acc, epoch=e)
            # scheduler.step()
            # scheduler.batch_step()
            accuracy.append(lb_acc)

            # Checkpoints
            if lb_acc > best_lb_acc:
                save_checkpoint(model,
                                optimizer=optimizer,
                                extra={
                                    'epoch': start_epoch + e,
                                    'lb_acc': lb_acc
                                },
                                checkpoint='fold-%d.pth' % i)
                state['lb_acc'] = lb_acc

            # if e % CYCLES == 0:
            #     save_checkpoint(model, optimizer=optimizer, extra={
            #         'epoch': start_epoch + e,
            #         'lb_acc': lb_acc
            #     }, checkpoint='cycle-%d-%.3f.pth' % (e % CYCLES, lb_acc))
            # el
            if e % 30 == 0 or e == TRAINING_EPOCH - 1:
                save_checkpoint(model,
                                optimizer=optimizer,
                                extra={
                                    'epoch': start_epoch + e,
                                    'lb_acc': lb_acc
                                },
                                checkpoint='ep%s-%.3f.pth' %
                                (start_epoch + e, lb_acc))
            else:
                save_checkpoint(model,
                                optimizer=optimizer,
                                extra={
                                    'epoch': start_epoch + e,
                                    'lb_acc': lb_acc
                                },
                                checkpoint='last.pth')

            # Tensorboard
            board.add_scalars('seresnet/losses', {
                'train_loss': train_loss,
                'val_loss': val_loss
            }, e)

            board.add_scalar('seresnet/lb_acc', lb_acc, e)

            log = "Epoch: %d, Fold %d, Train: %.3f, Val: %.3f, LB: %.3f (Best: %.3f)" % (
                start_epoch + e, i, train_loss, val_loss, lb_acc, best_lb_acc)

            print(log)

        print("Mean accuracy %.3f , Variance %.3f" %
              (np.mean(accuracy), np.var(accuracy)))

    test_dataset, test_file_list = get_test_dataset(DIRECTORY)
    test_predictions = []
    for (model, _, _, _) in loaded_models:
        model.eval()
        with torch.no_grad():
            all_predictions_stacked = test_tta(model, test_dataset)
        test_predictions.append(all_predictions_stacked)

    fold_mean_prediciton = np.mean(test_predictions, axis=0)
    binary_prediction = (fold_mean_prediciton > BIN_THRESHOLD).astype(int)
    # predictions = binary_prediction
    predictions = clear_small_masks(binary_prediction)
    submit = build_submission(predictions, test_file_list)
    submit.to_csv('submitM%.3fV%.3f.csv' %
                  (np.mean(accuracy), np.std(accuracy)),
                  index=False)
    board.close()
Example #20
def train(model, optimizer, lr_scheduler,
          train_data_iterator, val_data, timers, args, tokenizer):
    """Train the model."""

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0

    # Iterations.
    iteration = args.iteration
    skipped_iters = 0
    tb_writer = None
    if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        tb_writer = SummaryWriter(log_dir=args.logging_dir)

    timers('interval time').start()
    report_memory_flag = True
    is_master = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    print('--Start training loop--')
    train_start = True
    # avg_lm_loss = 1e6
    while iteration < args.train_iters:
        timers('data loader').start()
        sample = next(train_data_iterator) if (train_data_iterator is not None) else None
        timers('data loader').stop()

        if train_start and is_master:
            batch_text = f"\nIteration {iteration} start sample: {tokenizer.decode(sample[0, :200])}"
            tb_writer.add_text('train_start', batch_text, iteration)

        lm_loss, skipped_iter = train_step(sample,
                                           model,
                                           optimizer,
                                           lr_scheduler,
                                           args, timers, tokenizer, iteration, tb_writer)
        skipped_iters += skipped_iter
        iteration += 1
        train_start = False

        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()

        # Logging.
        if is_master and iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            ppl = math.exp(avg_lm_loss)
            elapsed_time = timers('interval time').elapsed()
            samples = args.log_interval * mpu.get_data_parallel_world_size() * args.batch_size
            tokens = samples * args.seq_length
            log_string = ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters)
            log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time * 1000.0 / args.log_interval)
            log_string += ' learning rate {:.3E} |'.format(learning_rate)
            log_string += ' lm loss {:.4f} |'.format(avg_lm_loss)
            log_string += ' perplexity {:.4f} |'.format(ppl)
            scalars = {
                'Loss/loss': avg_lm_loss,
                'Loss/perplexity': ppl,
                'learning_rate': learning_rate,
                'Speed/iteration_time_ms': (elapsed_time * 1000.0 / args.log_interval),
                'Speed/samples_per_sec': (samples / elapsed_time),
                'Speed/tokens_per_sec': (tokens / elapsed_time),
                'Speed/tokens_per_step': (tokens / args.log_interval),
                'Speed/seen_tokens': iteration * (tokens / args.log_interval)
            }
            if args.fp16:
                lscale = optimizer.cur_scale if DEEPSPEED_WRAP and args.deepspeed else optimizer.loss_scale
                log_string += ' loss scale {:.1f} |'.format(lscale)
                scalars['lscale'] = lscale
            print_rank_0(log_string)
            for k, v in scalars.items():
                tb_writer.add_scalar(k, v, iteration)

            if ppl < 3:
                # generate only when model is relatively good
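                # the prefix below is a Russian news-style prompt ("Brazilian scientists
                # have discovered a rare species of dwarf unicorn living in western Jutland")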
                prefix = 'Бразильские ученые открыли редкий вид карликовых единорогов, обитающих на западе Ютландии'
                model.eval()
                with torch.no_grad():
                    text = generate(model, tokenizer, prefix, 128)
                model.train()
                tb_writer.add_text('sample', text, iteration)

            if args.log_memory:
                log_memory_usage(tb_writer, iteration)
            total_lm_loss = 0.0
            if report_memory_flag:
                report_memory('after {} iterations'.format(iteration))
                report_memory_flag = False
            if USE_TORCH_DDP:
                timers.log(['forward', 'backward', 'optimizer', 'data loader'], normalizer=args.log_interval)
            else:
                timers.log(['forward', 'backward', 'allreduce', 'optimizer', 'data loader'],
                           normalizer=args.log_interval)
        # Checkpointing
        if args.save and args.save_interval and iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args, deepspeed=DEEPSPEED_WRAP and args.deepspeed)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            val_loss, val_ppl = evaluate_and_print_results(
                prefix, iter(val_data) if val_data else None, model, args, timers, False)
            if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
                scalars = {'val_loss': val_loss, 'val_perplexity': val_ppl}
                for k, v in scalars.items():
                    tb_writer.add_scalar(k, v, iteration)

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print('rank: {} | time: {} | exiting the program at iteration {}'.
                  format(rank, time_str, iteration), flush=True)
            exit()

    return iteration, skipped_iters
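
Example #20 invokes save_checkpoint(iteration, model, optimizer, lr_scheduler, args, deepspeed=...) from every rank; the Megatron/DeepSpeed-style saver that coordinates the ranks is not shown here. As a generic, hedged sketch (not that project's implementation) of an iteration-keyed saver that writes from rank 0 only:

import os

import torch
import torch.distributed as dist


def save_checkpoint(iteration, model, optimizer, lr_scheduler, args):
    """Write an iteration-keyed checkpoint from rank 0, with a barrier so all
    ranks resume together. Real Megatron/DeepSpeed savers also shard state and
    record RNG/argument metadata; this sketch omits that.
    """
    if not dist.is_initialized() or dist.get_rank() == 0:
        os.makedirs(args.save, exist_ok=True)
        torch.save(
            {
                "iteration": iteration,
                "model": getattr(model, "module", model).state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict() if lr_scheduler else None,
            },
            os.path.join(args.save, "iter_{:07d}.pt".format(iteration)),
        )
    if dist.is_initialized():
        dist.barrier()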
Example #21
    if not args.evaluate_only:
        best_loss = float('inf')  # sys.maxint is Python 2-only

        for epoch in range(1, args.epochs + 1):
            train_loss = train(epoch)
            val_losses = val(epoch)
            val_loss_sum = sum(val_losses)

            is_best = val_loss_sum < best_loss
            best_loss = min(val_loss_sum, best_loss)

            save_checkpoint(
                {
                    'state_dict': infernet.state_dict(),
                    'cmd_line_args': args,
                },
                is_best,
                folder=args.out_dir)

    print('loading best performing model')
    checkpoint = torch.load(os.path.join(args.out_dir, 'model_best.pth.tar'))
    state_dict = checkpoint['state_dict']

    # NOTE: this includes out-of-sample
    plane_lengths = np.arange(1, 20, 1)
    plane_degrees = np.arange(5, 85, 5)
    planes = list(product(plane_lengths, plane_degrees))
    n_planes = len(planes)

    # out-of-sample models -- measure how well we can do
Example #22
def main():
    # parse command line argument and generate config dictionary
    config = parse_args()
    logger.info(json.dumps(config, indent=2))

    run_config = config['run_config']
    optim_config = config['optim_config']

    # Code for saving in the correct place
    all_arguments = {}
    for key in config.keys():
        all_arguments.update(config[key])

    run_config['save_name'] = run_config['save_name'].format(**all_arguments)
    print('Saving in ' + run_config['save_name'])
    # End code for saving in the right place

    if run_config['test_config']:
        sys.exit(0)

    # TensorBoard SummaryWriter
    if run_config['tensorboard']:
        writer = SummaryWriter(run_config['outdir'])
    else:
        writer = None

    # create output directory
    outdir = pathlib.Path(run_config['outdir'])
    outdir.mkdir(exist_ok=True, parents=True)

    # save config as json file in output directory
    outpath = outdir / 'config.json'
    with open(outpath, 'w') as fout:
        json.dump(config, fout, indent=2)

    # load data loaders
    train_loader, test_loader = get_loader(config['data_config'])

    # set random seed (this was moved after the data loading because the data
    # loader might have a random seed)
    seed = run_config['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    epoch_seeds = np.random.randint(np.iinfo(np.int32).max // 2,
                                    size=optim_config['epochs'])

    # load model
    logger.info('Loading model...')
    model = utils.load_model(config['model_config'])
    n_params = sum([param.view(-1).size()[0] for param in model.parameters()])
    logger.info('n_params: {}'.format(n_params))

    if run_config['count_params']:
        # this option means just count the number of parameters, then move on
        sys.exit(0)

    if run_config['fp16'] and not run_config['use_amp']:
        model.half()
        for layer in model.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    device = torch.device(run_config['device'])
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)
    logger.info('Done')

    train_criterion, test_criterion = utils.get_criterion(
        config['data_config'])

    # create optimizer
    if optim_config['no_weight_decay_on_bn']:
        params = [
            {
                'params': [
                    param for name, param in model.named_parameters()
                    if 'bn' not in name
                ]
            },
            {
                'params': [
                    param for name, param in model.named_parameters()
                    if 'bn' in name
                ],
                'weight_decay': 0
            },
        ]
    else:
        params = model.parameters()
    optim_config['steps_per_epoch'] = len(train_loader)
    optimizer, scheduler = utils.create_optimizer(params, optim_config)

    # for mixed-precision
    amp_handle = apex.amp.init(
        enabled=run_config['use_amp']) if is_apex_available else None

    # run test before start training
    if run_config['test_first']:
        test(0, model, test_criterion, test_loader, run_config, writer)

    state = {
        'config': config,
        'state_dict': None,
        'optimizer': None,
        'epoch': 0,
        'accuracy': 0,
        'best_accuracy': 0,
        'best_epoch': 0,
    }
    epoch_logs = []
    for epoch, seed in zip(range(1, optim_config['epochs'] + 1), epoch_seeds):
        np.random.seed(seed)
        # train
        train_log = train(epoch, model, optimizer, scheduler, train_criterion,
                          train_loader, config, writer, amp_handle)

        # test
        test_log = test(epoch, model, test_criterion, test_loader, run_config,
                        writer)

        epoch_log = train_log.copy()
        epoch_log.update(test_log)
        epoch_logs.append(epoch_log)
        utils.save_epoch_logs(epoch_logs, outdir)

        # update state dictionary
        state = update_state(state, epoch, epoch_log['test']['accuracy'],
                             model, optimizer)

        # save model
        utils.save_checkpoint(state, outdir)
    """
    Upload to bucket code
    """

    from google.cloud import storage
    import os

    client = storage.Client()
    bucket = client.get_bucket('ramasesh-bucket-1')
    filenames = os.listdir(outdir)

    for filename in filenames:
        print('Processing file: ' + filename)

        blob = bucket.blob(run_config['save_name'] + filename)
        blob.upload_from_filename(str(outdir) + '/' + filename)
    """
Example #23
0
    def start(self):
        # Training Start
        start_time = time.time()

        if self.args.evaluate:
            # Loading evaluating model
            print('Loading evaluating model ...')
            checkpoint = U.load_checkpoint(self.model_name)
            self.model.module.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print('Successful!\n')

            # Start evaluating
            print('Starting evaluating ...')
            self.model.module.eval()
            acc = self.eval()
            print('Finish evaluating!')
            print('Best accuracy: {:2.2f}%, Total time:{:.4f}s'.format(
                acc,
                time.time() - start_time))

        else:
            # Resuming
            start_epoch, best_acc = 0, 0
            if self.args.resume:
                print('Loading checkpoint ...')
                checkpoint = U.load_checkpoint()
                self.model.module.load_state_dict(checkpoint['model'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                start_epoch = checkpoint['epoch']
                best_acc = checkpoint['best']
                print('Successful!\n')

            # Start training
            print('Starting training ...')
            self.model.module.train()
            for epoch in range(start_epoch, self.args.max_epoch):

                # Adjusting learning rate
                self.adjust_lr(epoch)

                # Training
                acc = self.train(epoch)
                print(
                    'Epoch: {}/{}, Training accuracy: {:2.2f}%, Training time: {:.4f}s\n'
                    .format(epoch + 1, self.args.max_epoch, acc,
                            time.time() - start_time))

                # Evaluating
                is_best = False
                if (epoch + 1) > self.args.adjust_lr[-1] and (epoch + 1) % 2 == 0:
                    print('Evaluating for epoch {} ...'.format(epoch + 1))
                    self.model.module.eval()
                    acc = self.eval()
                    print(
                        'Epoch: {}/{}, Evaluating accuracy: {:2.2f}%, Evaluating time: {:.4f}s\n'
                        .format(epoch + 1, self.args.max_epoch, acc,
                                time.time() - start_time))
                    self.model.module.train()
                    if acc > best_acc:
                        best_acc = acc
                        is_best = True

                # Saving model
                U.save_checkpoint(self.model.module.state_dict(),
                                  self.optimizer.state_dict(),
                                  epoch + 1, best_acc, is_best,
                                  self.model_name)
            print('Finish training!')
            print('Best accuracy: {:2.2f}%, Total time: {:.4f}s'.format(
                best_acc,
                time.time() - start_time))
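
The U.load_checkpoint / U.save_checkpoint pair used in Example #23 is assumed to bundle the model and optimizer states together with the epoch and best accuracy. A minimal sketch under that assumption (hypothetical, not the repository's utilities):

import torch

def save_checkpoint(model_state, optimizer_state, epoch, best_acc, is_best,
                    model_name, filename='checkpoint.pth.tar'):
    # Bundle everything needed to resume training into a single file.
    state = {'model': model_state, 'optimizer': optimizer_state,
             'epoch': epoch, 'best': best_acc}
    torch.save(state, filename)
    if is_best:
        torch.save(state, model_name)

def load_checkpoint(filename='checkpoint.pth.tar'):
    return torch.load(filename, map_location='cpu')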
Example #24
0
 )
 dev_loss, dev_acc = pipeline.test(
     model=model,
     dm=dm,
     loss_criterion=criterion,
     args=args,
     is_dev=True
 )
 if dev_acc > best_dev_acc:
     print('New best model: {} vs {}'.format(dev_acc, best_dev_acc))
     best_dev_acc = dev_acc
     save_checkpoint(
         state={
             'args': args,
             'epoch': epoch + 1,
             'state_dict': model.state_dict(),
             'acc': dev_acc,
             'best_acc': best_dev_acc,
             'optimizer': optimizer.state_dict(),
             'lr': state['lr']
         }, is_best=True)
 print('Saving to checkpoint')
 save_checkpoint(
     state={
         'args': args,
         'epoch': epoch + 1,
         'state_dict': model.state_dict(),
         'acc': dev_acc,
         'best_acc': best_dev_acc,
         'optimizer': optimizer.state_dict(),
         'lr': state['lr']
     }, is_best=False)
        end_time = time.time()
        train_losses[epoch] = train_loss
        train_times[epoch] = end_time - start_time
        
        if not args.no_test:
            test_loss = test(epoch)
            test_losses[epoch] = test_loss
            is_best = test_loss < best_loss
            best_loss = min(test_loss, best_loss)
        else:
            is_best = train_loss < best_loss
            best_loss = min(train_loss, best_loss)

        save_checkpoint({
            'model_state_dict': model.state_dict(),
            'epoch': epoch,
            'args': args,
        }, is_best, folder=args.out_dir)

        np.save(os.path.join(args.out_dir, 'train_losses.npy'), train_losses)
        np.save(os.path.join(args.out_dir, 'train_times.npy'), train_times)

        if not args.no_test:
            np.save(os.path.join(args.out_dir, 'test_losses.npy'),  test_losses)

    for checkpoint_name in ['checkpoint.pth.tar', 'model_best.pth.tar']:
        checkpoint = torch.load(os.path.join(args.out_dir, checkpoint_name))
        model.load_state_dict(checkpoint['model_state_dict'])

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
        end_time = time.time()
        train_losses[epoch] = train_loss
        train_times[epoch] = end_time - start_time
        
        if not args.no_test:
            test_loss = test(epoch)
            test_losses[epoch] = test_loss
            is_best = test_loss < best_loss
            best_loss = min(test_loss, best_loss)
        else:
            is_best = train_loss < best_loss
            best_loss = min(train_loss, best_loss)

        save_checkpoint({
            'model_state_dict': model.state_dict(),
            'epoch': epoch,
            'args': args,
            'total_iters': len(train_loader) * args.epochs,
        }, is_best, folder=args.out_dir)

        np.save(os.path.join(args.out_dir, 'train_losses.npy'), train_losses)
        np.save(os.path.join(args.out_dir, 'train_times.npy'), train_times)

        if not args.no_test:
            np.save(os.path.join(args.out_dir, 'test_losses.npy'),  test_losses)

    for checkpoint_name in ['checkpoint.pth.tar', 'model_best.pth.tar']:
        checkpoint = torch.load(os.path.join(args.out_dir, checkpoint_name))
        model.load_state_dict(checkpoint['model_state_dict'])

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
Example #27
0
        with torch.no_grad():

            model.forward(inputs)

            model.backward_G()  # needed to calculate losses

            if hasattr(model, 'dis_params'):
                model.backward_D()

        logs.update_losses('test')

    logs.update_tboard(epoch)

    # Save weights
    if not epoch % opt.save_every_epoch:

        utils.save_checkpoint(model, epoch)

        if hasattr(model, 'gen_params'):
            torch.save(
                opt_G.state_dict(),
                os.path.join(model.weights_path, '%d_opt_G.pkl' % epoch))

        if hasattr(model, 'dis_params'):
            torch.save(
                opt_D.state_dict(),
                os.path.join(model.weights_path, '%d_opt_D.pkl' % epoch))

logs.close()
utils.save_checkpoint(model, 'latest')
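
In Example #27, utils.save_checkpoint(model, tag) is assumed to write the model weights into model.weights_path under an epoch-number or 'latest' tag, mirroring how the optimizer states are saved just above. A sketch under that assumption (hypothetical filename pattern):

import os
import torch

def save_checkpoint(model, tag):
    # Save the model weights, keyed by epoch number or the string 'latest'.
    os.makedirs(model.weights_path, exist_ok=True)
    torch.save(model.state_dict(),
               os.path.join(model.weights_path, '%s_weights.pkl' % str(tag)))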
Example #28
0
File: Graphsage.py  Project: yangji9181/ALA
def train(args, model, Data, log_dir, logger, optimizer=None):
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    t = time.time()
    best_acc, best_epoch = 0, 0
    count = 0
    model.train()

    for epoch in range(1, args.epochs + 1):
        losses = []
        optimizer.zero_grad()
        (sampled_features, sampled_adj,
         prior), sampled_labels = Data.sample('link')
        loss = model(sampled_features, sampled_adj, sampled_labels)
        loss.backward()
        optimizer.step()

        if epoch % args.log_every == 0:
            losses.append(loss.item())
            duration = time.time() - t
            msg = 'Epoch: {:04d} '.format(epoch)
            msg += 'loss: {:.4f}\t'.format(loss.item())
            logger.info(msg + ' time: {:d}s '.format(int(duration)))

        if epoch % args.eval_every == 0:
            learned_embed = gensim.models.keyedvectors.Word2VecKeyedVectors(
                model.nembed)
            for i in range(0, len(args.nodes), args.sample_embed):
                nodes = args.nodes[i:i + args.sample_embed]
                features, adj, _ = Data.sample_subgraph(nodes, False)
                embedding = model.generate_embedding(features, adj)
                learned_embed.add([str(node) for node in nodes], embedding)
            train_acc, test_acc, std = evaluate(args, learned_embed, logger)
            duration = time.time() - t
            logger.info('Epoch: {:04d} '.format(epoch) +
                        'train_acc: {:.2f} '.format(train_acc) +
                        'test_acc: {:.2f} '.format(test_acc) +
                        'std: {:.2f} '.format(std) +
                        'time: {:d}s'.format(int(duration)))
            if test_acc > best_acc:
                best_acc = test_acc
                best_epoch = epoch
                save_checkpoint(
                    {
                        'args': args,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, log_dir,
                    f'epoch{epoch}_time{int(duration):d}_trainacc{train_acc:.2f}_testacc{test_acc:.2f}_std{std:.2f}.pth.tar',
                    logger, True)
                count = 0
            else:
                if args.early_stop:
                    count += args.eval_every
                if count >= args.patience:
                    logger.info('early stopped!')
                    break

    logger.info(f'best test acc={best_acc:.2f} @ epoch:{int(best_epoch):d}')

    if args.save_emb:
        learned_embed = gensim.models.keyedvectors.Word2VecKeyedVectors(
            model.nembed)
        for i in range(0, len(args.nodes), args.sample_embed):
            nodes = args.nodes[i:i + args.sample_embed]
            features, adj, _ = Data.sample_subgraph(nodes, False)
            embedding = model.generate_embedding(features, adj)
            learned_embed.add([str(node) for node in nodes], embedding)
        save_embedding(learned_embed,
                       args.save_emb_file,
                       binary=(os.path.splitext(
                           args.save_emb_file)[1] == 'bin'))

    return best_acc
Example #29
0
                f"--> eval - acc:{acc}\tprec:{prec}\trec:{rec}\tf1:{f_score}")

            eval_epochs.append(epoch)
            eval_accuracy.append(acc)
            eval_precision.append(prec)
            eval_recall.append(rec)
            eval_f_score.append(f_score)

            if f_score > best_f1:
                best_f1 = f_score
                state_dict = model.state_dict()
                state_dict = {k: v.cpu() for k, v in state_dict.items()}

                save_checkpoint(path_=path.join(settings.get("ckp_dir"),
                                                exp_name),
                                state=state_dict,
                                is_best=False)

    train_timeseries = [
        create_timeseries(t, n, train_epochs)
        for t, n in [(train_accuracy, "accuracy"), (losses, "loss")]
    ]
    plot_scalars(path.join(settings.get("run_dir"), exp_name, "train.png"),
                 train_timeseries)

    test_timeseries = [
        create_timeseries(t, n, eval_epochs)
        for t, n in [(eval_accuracy, "accuracy"), (eval_precision, "precision"),
                     (eval_recall, "recall"), (eval_f_score, "f_score")]
    ]
Example #30
0
def main(args):
    """ The main training function.

    Only works for single node (be it single or multi-GPU)

    Parameters
    ----------
    args :
        Parsed arguments
    """
    # setup
    ngpus = torch.cuda.device_count()
    if ngpus == 0:
        raise RuntimeWarning("This will not be able to run on CPU only")

    print(f"Working with {ngpus} GPUs")
    if args.optim.lower() == "ranger":
        # No warm up if ranger optimizer
        args.warm = 0

    current_experiment_time = datetime.now().strftime('%Y%m%d_%T').replace(":", "")
    args.exp_name = f"{'debug_' if args.debug else ''}{current_experiment_time}_" \
                    f"_fold{args.fold if not args.full else 'FULL'}" \
                    f"_{args.arch}_{args.width}" \
                    f"_batch{args.batch_size}" \
                    f"_optim{args.optim}" \
                    f"_{args.optim}" \
                    f"_lr{args.lr}-wd{args.weight_decay}_epochs{args.epochs}_deepsup{args.deep_sup}" \
                    f"_{'fp16' if not args.no_fp16 else 'fp32'}" \
                    f"_warm{args.warm}_" \
                    f"_norm{args.norm_layer}{'_swa' + str(args.swa_repeat) if args.swa else ''}" \
                    f"_dropout{args.dropout}" \
                    f"_warm_restart{args.warm_restart}" \
                    f"{'_' + args.com.replace(' ', '_') if args.com else ''}"
    args.save_folder = pathlib.Path(f"./runs/{args.exp_name}")
    args.save_folder.mkdir(parents=True, exist_ok=True)
    args.seg_folder = args.save_folder / "segs"
    args.seg_folder.mkdir(parents=True, exist_ok=True)
    args.save_folder = args.save_folder.resolve()
    save_args(args)
    t_writer = SummaryWriter(str(args.save_folder))

    # Create model
    print(f"Creating {args.arch}")

    model_maker = getattr(models, args.arch)

    model = model_maker(
        4, 3,
        width=args.width, deep_supervision=args.deep_sup,
        norm_layer=get_norm_layer(args.norm_layer), dropout=args.dropout)

    print(f"total number of trainable parameters {count_parameters(model)}")

    if args.swa:
        # Create the average model
        swa_model = model_maker(
            4, 3,
            width=args.width, deep_supervision=args.deep_sup,
            norm_layer=get_norm_layer(args.norm_layer))
        for param in swa_model.parameters():
            param.detach_()
        swa_model = swa_model.cuda()
        swa_model_optim = WeightSWA(swa_model)

    if ngpus > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
    print(model)
    model_file = args.save_folder / "model.txt"
    with model_file.open("w") as f:
        print(model, file=f)

    criterion = EDiceLoss().cuda()
    metric = criterion.metric
    print(metric)

    rangered = False  # needed because LR scheduling scheme is different for this optimizer
    if args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay, eps=1e-4)


    elif args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9,
                                    nesterov=True)

    elif args.optim == "adamw":
        print(f"weight decay argument will not be used. Default is 11e-2")
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    elif args.optim == "ranger":
        optimizer = Ranger(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        rangered = True

    # optionally resume from a checkpoint
    if args.resume:
        reload_ckpt(args, model, optimizer)

    if args.debug:
        args.epochs = 2
        args.warm = 0
        args.val = 1

    if args.full:
        train_dataset, bench_dataset = get_datasets(args.seed, args.debug, full=True)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=False, drop_last=True)

        bench_loader = torch.utils.data.DataLoader(
            bench_dataset, batch_size=1, num_workers=args.workers)

    else:

        train_dataset, val_dataset, bench_dataset = get_datasets(args.seed, args.debug, fold_number=args.fold)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=False, drop_last=True)

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=max(1, args.batch_size // 2), shuffle=False,
            pin_memory=False, num_workers=args.workers, collate_fn=determinist_collate)

        bench_loader = torch.utils.data.DataLoader(
            bench_dataset, batch_size=1, num_workers=args.workers)
        print("Val dataset number of batch:", len(val_loader))

    print("Train dataset number of batch:", len(train_loader))

    # create grad scaler
    scaler = GradScaler()

    # Actual Train loop

    best = np.inf
    print("start warm-up now!")
    if args.warm != 0:
        tot_iter_train = len(train_loader)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lambda cur_iter: (1 + cur_iter) / (tot_iter_train * args.warm))

    patients_perf = []

    if not args.resume:
        for epoch in range(args.warm):
            ts = time.perf_counter()
            model.train()
            training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer,
                                 scaler, scheduler, save_folder=args.save_folder,
                                 no_fp16=args.no_fp16, patients_perf=patients_perf)
            te = time.perf_counter()
            print(f"Train Epoch done in {te - ts} s")

            # Validate at the end of epoch every val step
            if (epoch + 1) % args.val == 0 and not args.full:
                model.eval()
                with torch.no_grad():
                    validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer, epoch,
                                           t_writer, save_folder=args.save_folder,
                                           no_fp16=args.no_fp16)

                t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch)

    if args.warm_restart:
        print('Total number of epochs should be divisible by 30, else it will do odd things')
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 30, eta_min=1e-7)
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               args.epochs + 30 if not rangered else round(
                                                                   args.epochs * 0.5))
    print("start training now!")
    if args.swa:
        # cycle length c, averaging interval k, number of SWA cycles
        c, k, repeat = 30, 3, args.swa_repeat
        epochs_done = args.epochs
        reboot_lr = 0
        if args.debug:
            c, k, repeat = 2, 1, 2

    for epoch in range(args.start_epoch + args.warm, args.epochs + args.warm):
        try:
            # do_epoch for one epoch
            ts = time.perf_counter()
            model.train()
            training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer,
                                 scaler, save_folder=args.save_folder,
                                 no_fp16=args.no_fp16, patients_perf=patients_perf)
            te = time.perf_counter()
            print(f"Train Epoch done in {te - ts} s")

            # Validate at the end of epoch every val step
            if (epoch + 1) % args.val == 0 and not args.full:
                model.eval()
                with torch.no_grad():
                    validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer,
                                           epoch,
                                           t_writer,
                                           save_folder=args.save_folder,
                                           no_fp16=args.no_fp16, patients_perf=patients_perf)

                t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch)

                if validation_loss < best:
                    best = validation_loss
                    model_dict = model.state_dict()
                    save_checkpoint(
                        dict(
                            epoch=epoch, arch=args.arch,
                            state_dict=model_dict,
                            optimizer=optimizer.state_dict(),
                            scheduler=scheduler.state_dict(),
                        ),
                        save_folder=args.save_folder, )

                ts = time.perf_counter()
                print(f"Val epoch done in {ts - te} s")

            if args.swa:
                if (args.epochs - epoch - c) == 0:
                    reboot_lr = optimizer.param_groups[0]['lr']

            if not rangered:
                scheduler.step()
                print("scheduler stepped!")
            else:
                if epoch / args.epochs > 0.5:
                    scheduler.step()
                    print("scheduler stepped!")

        except KeyboardInterrupt:
            print("Stopping training loop, doing benchmark")
            break

    if args.swa:
        swa_model_optim.update(model)
        print("SWA Model initialised!")
        for i in range(repeat):
            optimizer = torch.optim.Adam(model.parameters(), args.lr / 2, weight_decay=args.weight_decay)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, c + 10)
            for swa_epoch in range(c):
                # do_epoch for one epoch
                ts = time.perf_counter()
                model.train()
                swa_model.train()
                current_epoch = epochs_done + i * c + swa_epoch
                training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer,
                                     current_epoch, t_writer,
                                     scaler, no_fp16=args.no_fp16, patients_perf=patients_perf)
                te = time.perf_counter()
                print(f"Train Epoch done in {te - ts} s")

                t_writer.add_scalar(f"SummaryLoss/train", training_loss, current_epoch)

                # update every k epochs and val:
                print(f"cycle number: {i}, swa_epoch: {swa_epoch}, total_cycle_to_do {repeat}")
                if (swa_epoch + 1) % k == 0:
                    swa_model_optim.update(model)
                    if not args.full:
                        model.eval()
                        swa_model.eval()
                        with torch.no_grad():
                            validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer,
                                                   current_epoch,
                                                   t_writer, save_folder=args.save_folder, no_fp16=args.no_fp16)
                            swa_model_loss = step(val_loader, swa_model, criterion, metric, args.deep_sup, optimizer,
                                                  current_epoch,
                                                  t_writer, swa=True, save_folder=args.save_folder,
                                                  no_fp16=args.no_fp16)

                        t_writer.add_scalar(f"SummaryLoss/val", validation_loss, current_epoch)
                        t_writer.add_scalar(f"SummaryLoss/swa", swa_model_loss, current_epoch)
                        t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, current_epoch)
                        t_writer.add_scalar(f"SummaryLoss/overfit_swa", swa_model_loss - training_loss, current_epoch)
                scheduler.step()
        epochs_added = c * repeat
        save_checkpoint(
            dict(
                epoch=args.epochs + epochs_added, arch=args.arch,
                state_dict=swa_model.state_dict(),
                optimizer=optimizer.state_dict()
            ),
            save_folder=args.save_folder, )
    else:
        save_checkpoint(
            dict(
                epoch=args.epochs, arch=args.arch,
                state_dict=model.state_dict(),
                optimizer=optimizer.state_dict()
            ),
            save_folder=args.save_folder, )

    try:
        df_individual_perf = pd.DataFrame.from_records(patients_perf)
        print(df_individual_perf)
        df_individual_perf.to_csv(f'{str(args.save_folder)}/patients_indiv_perf.csv')
        reload_ckpt_bis(f'{str(args.save_folder)}/model_best.pth.tar', model)
        generate_segmentations(bench_loader, model, t_writer, args)
    except KeyboardInterrupt:
        print("Stopping right now!")