Example #1
def main(rootpath="data/"):
    logger = Logger(rootpath + "log.txt", "ExtAddApp.py", True)
    configService = ConfigService(rootpath)

    if not appex.is_running_extension():
        print('This script is intended to be run from the sharing extension.')
        return

    url = appex.get_url()
    if not url:
        console.alert("Error", "No input URL found.", 'OK', hide_cancel_button=True)

        if configService.getLog().getData() == 1:
            logger.error("No input URL found.")
        return

    console.hud_alert("正在抓取数据,请等待...", "success")

    appService = AppService(rootpath)
    res = appService.addApp(url)

    if res.equal(ResultEnum.APP_UPDATE):
        console.hud_alert("应用更新成功!", 'success')
    elif res.equal(ResultEnum.SUCCESS):
        console.hud_alert("应用添加成功!", 'success')
    else:
        console.hud_alert(res.getInfo(), 'error')
Example #2
def main(config):
    loader = Loader(config)
    base = Base(config, loader)
    make_dirs(base.output_path)
    make_dirs(base.save_logs_path)
    make_dirs(base.save_model_path)
    logger = Logger(os.path.join(base.save_logs_path, 'log.txt'))
    logger(config)

    if config.mode == 'train':
        if config.resume_train_epoch >= 0:
            base.resume_model(config.resume_train_epoch)
            start_train_epoch = config.resume_train_epoch
        else:
            start_train_epoch = 0

        if config.auto_resume_training_from_lastest_step:
            root, _, files = os_walk(base.save_model_path)
            if len(files) > 0:
                indexes = []
                for file in files:
                    indexes.append(int(
                        file.replace('.pkl', '').split('_')[-1]))
                indexes = sorted(list(set(indexes)), reverse=False)
                base.resume_model(indexes[-1])
                start_train_epoch = indexes[-1]
                logger(
                    'Time: {}, automatically resume training from the latest step (model {})'
                    .format(time_now(), indexes[-1]))

        for current_epoch in range(start_train_epoch,
                                   config.total_train_epoch):
            base.save_model(current_epoch)

            if current_epoch < config.use_graph:
                _, result = train_meta_learning(base, loader)
                logger('Time: {}; Epoch: {}; {}'.format(
                    time_now(), current_epoch, result))
                if current_epoch + 1 >= 1 and (current_epoch + 1) % 40 == 0:
                    mAP, CMC = test(config, base, loader)
                    logger(
                        'Time: {}; Test on Target Dataset: {}, \nmAP: {} \n Rank: {}'
                        .format(time_now(), config.target_dataset, mAP, CMC))
            else:
                _, result = train_with_graph(config, base, loader)
                logger('Time: {}; Epoch: {}; {}'.format(
                    time_now(), current_epoch, result))
                if current_epoch + 1 >= 1 and (current_epoch + 1) % 5 == 0:
                    mAP, CMC = test_with_graph(config, base, loader)
                    logger(
                        'Time: {}; Test on Target Dataset: {}, \nmAP: {} \n Rank: {}'
                        .format(time_now(), config.target_dataset, mAP, CMC))

    elif config.mode == 'test':
        base.resume_model(config.resume_test_model)
        mAP, CMC = test_with_graph(config, base, loader)
        logger('Time: {}; Test on Target Dataset: {}, \nmAP: {} \n Rank: {}'.
               format(time_now(), config.target_dataset, mAP, CMC))
Example #3
    def __init__(self, rootpath="data/"):
        self.rootpath = rootpath

        dbpath = self.rootpath + "database.db"
        self.mPriceController = PriceController(dbpath)

        self.logger = Logger(self.rootpath + "log.txt", "PriceService.py",
                             True)
Example #4
    def __init__(self, rootpath="../data/"):
        self.rootpath = rootpath

        dbpath = self.rootpath + "database.db"
        self.mConfigController = ConfigController(dbpath)

        self.logger = Logger(self.rootpath + "log.txt", "ConfigService.py",
                             True)
Example #5
    def __init__(self, WebInsMgrRef, name):
        self.WebInsMgrRef = WebInsMgrRef
        self.Logger = Logger(name, 'WebInstance')
        self.instance_name = name
        self.LastAccessTime = str()
        self.LastEpoch = 0
        self.LongLoad = False
        self.InstanceKeepAliveIns = type(InstanceKeepAlive)
        self.KeepAlivePythonProcess = None
        self.port = int(self.WebInsMgrRef.basePort) + int(self.instance_name)
Example #6
    def __init__(self, rootpath="../data/"):
        self.rootpath = rootpath

        dbpath = self.rootpath + "database.db"
        self.mAppController = AppController(dbpath)
        self.mPriceService = PriceService(rootpath)
        self.mConfigService = ConfigService(rootpath)

        self.mNotification = Notification("AppWishList")

        self.logger = Logger(self.rootpath + "log.txt", "AppService.py", True)
Example #7
def eva_a_phi(phi):
    na, nnh, nh, nw = phi

    # choose a dataset to train (mscoco, flickr8k, flickr30k)
    dataset = 'mscoco'
    data_dir = osp.join(DATA_ROOT, dataset)

    from model.ra import Model
    # settings
    mb = 64  # mini-batch size
    lr = 0.0002  # learning rate
    # nh = 512  # size of LSTM's hidden size
    # nnh = 512  # hidden size of attention mlp
    # nw = 512  # size of word embedding vector
    # na = 512  # size of the region features after dimensionality reduction
    name = 'ra'  # model name, just setting it to 'ra' is ok. 'ra'='region attention'
    vocab_freq = 'freq5'  # use the vocabulary that filters out words whose frequencies are less than 5

    print '... loading data {}'.format(dataset)
    train_set = Reader(batch_size=mb, data_split='train', vocab_freq=vocab_freq, stage='train',
                       data_dir=data_dir, feature_file='features_30res.h5', topic_switch='off') # change 0, 1000, 82783
    valid_set = Reader(batch_size=1, data_split='val', vocab_freq=vocab_freq, stage='val',
                       data_dir=data_dir, feature_file='features_30res.h5',
                       caption_switch='off', topic_switch='off') # change 0, 10, 5000

    npatch, nimg = train_set.features.shape[1:]
    nout = len(train_set.vocab)
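    # the snapshot directory name below encodes the evaluated hyper-parameters, batch size and vocabulary size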
    save_dir = '{}-nnh{}-nh{}-nw{}-na{}-mb{}-V{}'.\
        format(dataset.lower(), nnh, nh, nw, na, mb, nout)
    save_dir = osp.join(SAVE_ROOT, save_dir)

    model_file, m = find_last_snapshot(save_dir, resume_training=False)
    os.system('cp model/ra.py {}/'.format(save_dir))
    logger = Logger(save_dir)
    logger.info('... building')
    model = Model(name=name, nimg=nimg, nnh=nnh, nh=nh, na=na, nw=nw, nout=nout, npatch=npatch, model_file=model_file)

    # start training
    bs = BeamSearch([model], beam_size=1, num_cadidates=100, max_length=20)
    best = train(model, bs, train_set, valid_set, save_dir, lr,
                 display=100, starting=m, endding=20, validation=2000, life=10, logger=logger) # change dis1,100; va 2,2000; life 0,10;
    average_models(best=best, L=6, model_dir=save_dir, model_name=name+'.h5') # L 1, 6

    # evaluation
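    # hand data/save paths to the external evaluation script via .npy files (presumably read by valid_time.py)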
    np.save('data_dir', data_dir)
    np.save('save_dir', save_dir)

    os.system('python valid_time.py')

    scores = np.load('scores.npy')
    running_time = np.load('running_time.npy')
    print 'cider:', scores[-1], 'B1-4,C:', scores, 'running time:', running_time

    return scores, running_time
Example #8
    def train(self, exp_name, params, logname):
        self.logger = Logger(params, logname)
        sgd = optimizers.SGD(lr=params["learning_rate"],
                             momentum=0.0,
                             decay=0.0,
                             nesterov=False)
        self.model.compile(optimizer=sgd, loss='mean_squared_error')
        if exp_name == "lowest":
            for i in range(params["repeats"]):
                pair = self.play_games(params["num_batch"])
                self.train_on_played_games_lowest(pair[0], pair[1], params)
        self.model.save(self.logger.filename + "-model")
Example #9
def main(config):

    # init loaders and base
    loaders = ReIDLoaders(config)
    base = Base(config)

    # make directories
    make_dirs(base.output_path)

    # init logger
    logger = Logger(os.path.join(config.output_path, 'log.txt'))
    logger(config)

    assert config.mode in ['train', 'test', 'visualize']
    if config.mode == 'train':  # train mode

        # automatically resume model from the latest one
        if config.auto_resume_training_from_lastest_steps:
            start_train_epoch = base.resume_last_model()

        # main loop
        for current_epoch in range(start_train_epoch,
                                   config.total_train_epochs):
            # save model
            base.save_model(current_epoch)
            # train
            _, results = train_an_epoch(config, base, loaders, current_epoch)
            logger('Time: {};  Epoch: {};  {}'.format(time_now(),
                                                      current_epoch, results))

        # test
        base.save_model(config.total_train_epochs)
        mAP, CMC, pres, recalls, thresholds = test(config, base, loaders)
        logger('Time: {}; Test Dataset: {}, \nmAP: {} \nRank: {}'.format(
            time_now(), config.test_dataset, mAP, CMC))
        plot_prerecall_curve(config, pres, recalls, thresholds, mAP, CMC,
                             'none')

    elif config.mode == 'test':  # test mode
        base.resume_from_model(config.resume_test_model)
        mAP, CMC, pres, recalls, thresholds = test(config, base, loaders)
        logger('Time: {}; Test Dataset: {}, \nmAP: {} \nRank: {}'.format(
            time_now(), config.test_dataset, mAP, CMC))
        logger(
            'Time: {}; Test Dataset: {}, \nprecision: {} \nrecall: {}\nthresholds: {}'
            .format(time_now(), config.test_dataset, pres, recalls, thresholds))
        plot_prerecall_curve(config, pres, recalls, thresholds, mAP, CMC,
                             'none')

    elif config.mode == 'visualize':  # visualization mode
        base.resume_from_model(config.resume_visualize_model)
        visualize(config, base, loaders)
Example #10
def main(config):

    # init loaders and base
    loaders = ReIDLoaders(config)
    base = Base(config)

    # make directories
    make_dirs(base.output_path)

    # init logger
    logger = Logger(os.path.join(config.output_path, 'log.txt'))
    logger(config)

    assert config.mode in ['train', 'test', 'visualize']
    if config.mode == 'train':  # train mode

        # automatically resume model from the latest one
        if config.auto_resume_training_from_lastest_steps:
            print('resume', base.output_path)
            start_train_epoch = base.resume_last_model()
        #start_train_epoch = 0

        # main loop
        for current_epoch in range(start_train_epoch,
                                   config.total_train_epochs + 1):
            # save model
            base.save_model(current_epoch)
            # train
            base.lr_scheduler.step(current_epoch)
            _, results = train_an_epoch(config, base, loaders)
            logger('Time: {};  Epoch: {};  {}'.format(time_now(),
                                                      current_epoch, results))

        # test
        base.save_model(config.total_train_epochs)
        mAP, CMC = test(config, base, loaders)
        logger('Time: {}; Test Dataset: {}, \nmAP: {} \nRank: {}'.format(
            time_now(), config.test_dataset, mAP, CMC))

    elif config.mode == 'test':  # test mode
        base.resume_from_model(config.resume_test_model)
        mAP, CMC = test(config, base, loaders)
        logger('Time: {}; Test Dataset: {}, \nmAP: {} \nRank: {} with len {}'.
               format(time_now(), config.test_dataset, mAP, CMC, len(CMC)))

    elif config.mode == 'visualize':  # visualization mode
        base.resume_from_model(config.resume_visualize_model)
        visualize(config, base, loaders)
Example #11
    def __init__(self, args):
        now_time = datetime.datetime.strftime(datetime.datetime.now(),
                                              '%m%d-%H%M%S')
        args.cur_dir = os.path.join(args.exp_dir, now_time)
        args.log_path = os.path.join(args.cur_dir, 'train.log')
        args.best_model_path = os.path.join(args.cur_dir, 'best_model.pth')

        self.args = args
        mkdir(self.args.exp_dir)
        mkdir(self.args.cur_dir)
        self.log = Logger(self.args.log_path, level='debug').logger
        self.log.critical("args: \n{}".format(to_str_args(self.args)))

        self.train_loader = torch.utils.data.DataLoader(
            dataset=CUB200Dataset(root=self.args.root, train=True),
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=CUB200Dataset(root=self.args.root, train=False),
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            shuffle=False)

        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(in_features=self.model.fc.in_features,
                                  out_features=self.args.num_classes)
        self.model.cuda()

        self.criterion = nn.CrossEntropyLoss()
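        # choose Adam or SGD depending on args.optim_type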
        self.optimizer = torch.optim.Adam(
            params=self.model.parameters(), lr=self.args.lr
        ) if self.args.optim_type == 'Adam' else torch.optim.SGD(
            params=self.model.parameters(),
            lr=self.args.lr,
            momentum=self.args.momentum,
            weight_decay=self.args.decay)

        self.log.critical("model: \n{}".format(self.model))
        self.log.critical("torchsummary: \n{}".format(
            summary(model=self.model, input_size=(3, 224, 224))))
        self.log.critical("criterion: \n{}".format(self.criterion))
        self.log.critical("optimizer: \n{}".format(self.optimizer))
Example #12
def train(model, beam_searcher, train_set, valid_set, save_dir, lr,
          display=100, starting=0, endding=20, validation=2000, life=10, logger=None):
    """
    display:    output training infomation every 'display' mini-batches
    starting:   the starting snapshots, > 0 when resuming training
    endding:    the least training snapshots
    validation: evaluate on validation set every 'validation' mini-batches
    life:       increase of endding when finds better model
    """
    train_func, _ = adam_optimizer(model, lr=lr)
    print '... training'
    logger = Logger(save_dir) if logger is None else logger
    timer = Timer()
    loss = 0
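    # imb counts mini-batches; a snapshot is saved every 'validation' mini-batches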
    imb = starting * validation
    best = -1
    best_snapshot = -1
    timer.tic()
    while imb < endding*validation:
        imb += 1
        x = train_set.iterate_batch()
        loss += train_func(*x)[0] / display
        if imb % display == 0:
            logger.info('snapshot={}, iter={},  loss={:.6f},  time={:.1f} sec'.format(imb/validation, imb, loss, timer.toc()))
            timer.tic()
            loss = 0
        if imb % validation == 0:
            saving_index = imb/validation
            model.save_to_dir(save_dir, saving_index)
            try:
                scores = validate(beam_searcher, valid_set, logger)
                if scores[3] > best:
                    best = scores[3]
                    best_snapshot = saving_index
                    endding = max(saving_index+life, endding)
                logger.info('    ---- this Bleu-4 = [%.3f],   best Bleu-4 = [%.3f], endding -> %d' % \
                            (scores[3], best, endding))
            except OSError:
                print '[Ops!! OS Error]'

    logger.info('Training done, best snapshot is [%d]' % best_snapshot)
    return best_snapshot
Example #13
def validate(beam_searcher, dataset, logger=None, res_file=None):
    if logger is None:
        logger = Logger(None)
    # generating captions
    all_candidates = []
    for i in xrange(dataset.n_image):
        data = dataset.iterate_batch()  # data: id, img, scene...
        sent = beam_searcher.generate(data[1:])
        cap = ' '.join([dataset.vocab[word] for word in sent])
        print '[{}], id={}, \t\t {}'.format(i, data[0], cap)
        all_candidates.append({'image_id': data[0], 'caption': cap})

    if res_file is None:
        res_file = 'tmp.json'
    json.dump(all_candidates, open(res_file, 'w'))
    gt_file = osp.join(dataset.data_dir, 'captions_'+dataset.data_split+'.json')
    scores = evaluate(gt_file, res_file, logger)
    if res_file == 'tmp.json':
        os.system('rm -rf %s' % res_file)

    return scores
Example #14
def main(rootpath="data/"):
    logger = Logger(rootpath + "log.txt", "AutoUpdateApp.py", True)
    configService = ConfigService(rootpath)

    if (configService.getLog().getData() == 1):
        logger.info("开始自动更新:")

    if (not isConnected("http://www.baidu.com")):
        if (configService.getLog().getData() == 1):
            logger.error("网络连接超时!\n")
        return

    serv = AppService(rootpath)

    res = serv.updateAllApps()

    if (not res.equal(ResultEnum.SUCCESS)):
        if (configService.getLog().getData() == 1):
            logger.error("自动更新出错: " + res.toString())
    else:
        if (configService.getLog().getData() == 1):
            logger.info("自动更新完成。\n")
Example #15
    def __init__(self, session, WebInsMgrRef):
        import subprocess
        self.WebInsMgrRef = WebInsMgrRef
        self.Logger = Logger(session, 'InstanceKeepAlive')
        self.session = session
        self.LastTime = self.CheckTime()
        self.port = int(WebInsMgrRef.basePort) + int(self.session)
        # source lines of a small helper script that curls the server
        # every 200 seconds to keep this session alive
        self.KeepAlivePython = [
            'import time,subprocess',
            'session = %d' % (int(session)), 'while True:',
            '    time.sleep(200)',
            "    subp = subprocess.Popen(['curl', 'http://10.216.35.20/GetEpoch/%d' %(int(session))])"
        ]
        # remove any stale keep-alive script, rewrite it, then launch it as a child process
        subprocess.Popen(
            ['rm', '-f',
             'session%dkeepalive.py' % (int(self.port))])
        time.sleep(2)
        to_open = 'session%dkeepalive.py' % (self.port)
        pyKeepAlive = open(to_open, 'w')
        for eachLine in self.KeepAlivePython:
            pyKeepAlive.write('%s\n' % (eachLine))
        pyKeepAlive.close()
        self.KeepAlivePythonProcess = subprocess.Popen(
            ['python', 'session%dkeepalive.py' % (self.port)])
Example #16
def main(device, args):
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=True, **args.aug_kwargs),
                            train=True,
                            **args.dataset_kwargs),
        shuffle=True,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs)
    memory_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=False,
                                              train_classifier=False,
                                              **args.aug_kwargs),
                            train=True,
                            **args.dataset_kwargs),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(train=False,
                          train_classifier=False,
                          **args.aug_kwargs),
        train=False,
        **args.dataset_kwargs),
                                              shuffle=False,
                                              batch_size=args.train.batch_size,
                                              **args.dataloader_kwargs)

    # define model
    model = get_model(args.model).to(device)
    model = torch.nn.DataParallel(model)

    # define optimizer
    optimizer = get_optimizer(args.train.optimizer.name,
                              model,
                              lr=args.train.base_lr * args.train.batch_size /
                              256,
                              momentum=args.train.optimizer.momentum,
                              weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs,
        args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs,
        args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )

    logger = Logger(tensorboard=args.logger.tensorboard,
                    matplotlib=args.logger.matplotlib,
                    log_dir=args.log_dir)
    accuracy = 0
    # Start training

    print("Trying to train model {}".format(model))
    print("Will run up to {} epochs of training".format(
        args.train.stop_at_epoch))

    global_progress = tqdm(range(0, args.train.stop_at_epoch),
                           desc=f'Training')
    for epoch in global_progress:
        model.train()

        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.train.num_epochs}',
                              disable=args.hide_progress)
        for idx, _data in enumerate(local_progress):
            # TODO looks like we might be missing the label?
            ((images1, images2), labels) = _data

            model.zero_grad()
            data_dict = model.forward(images1.to(device, non_blocking=True),
                                      images2.to(device, non_blocking=True))
            loss = data_dict['loss'].mean()  # ddp
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})

            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        # ignore KNN monitor since it's coded to work ONLY on cuda enabled devices unfortunately
        # check the mnist yaml to see
        if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
            accuracy = knn_monitor(model.module.backbone,
                                   memory_loader,
                                   test_loader,
                                   device,
                                   k=min(args.train.knn_k,
                                         len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

    # Save checkpoint
    model_path = os.path.join(
        args.ckpt_dir,
        f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth"
    )  # datetime.now().strftime('%Y%m%d_%H%M%S')
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.module.state_dict()
    }, model_path)
    print(f"Model saved to {model_path}")
    with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f:
        f.write(f'{model_path}')

    if args.eval is not False:
        args.eval_from = model_path
        linear_eval(args)
Example #17
    participantID = args[1]
    path = args[2]
    plot_path = args[3]
    log_path = args[4]

p = Participant(id=participantID, path=path)
p.activeSensingFilenameSelector = 'diary'
p.metaDataFileName = 'meta_patients.json'
p.sleepSummaryFileName = 'FB_sleep_summary.csv'
p.load()

p.pipelineStatus['GP model sim.'] = False
#p.saveSnapshot(p.path)
print(p)

log = Logger(log_path, 'sleepsight' + p.id + '.log', printLog=True)
log.emit('BEGIN ANALYSIS PIPELINE', newRun=True)

# Task: 'trim data' to Study Duration
if not p.isPipelineTaskCompleted('trim data'):
    log.emit('Continuing with TRIM DATA...')
    p.trimData(p.info['startDate'], duration=56)
    p.updatePipelineStatusForTask('trim data', log=log)
    p.saveSnapshot(path, log=log)
else:
    log.emit('Skipping TRIM DATA - already completed.')

# Task: 'missingness' (Decision tree: No missingness vs not worn vs not charged)
if not p.isPipelineTaskCompleted('missingness'):
    log.emit('Continuing with MISSINGNESS computation...')
    mdt = MissingnessDT(passiveData=p.passiveData,
Example #18
# -*- coding: utf-8 -*-

from tools import Logger
from abc import ABCMeta

log = Logger(__name__)


class Drivers(metaclass=ABCMeta):
    def __init__(self,
                 instance: object,
                 driver_name: str,
                 driver_parameters=None) -> None:
        self.driver_name = driver_name
        self.driver_parameters = driver_parameters
        self.channel = instance

        log.debug(f'Driver Initialized: {self.driver_name}')

    def start(self):
        pass

    def stop(self):
        pass
Example #19
File: main.py Project: dign50501/SimSiam
def main(gpu, args):
    rank = args.nr * args.gpus + gpu
    dist.init_process_group("nccl", rank=rank, world_size=args.world_size)

    torch.manual_seed(0)
    torch.cuda.set_device(gpu)

    train_dataset = get_dataset(transform=get_aug(train=True,
                                                  **args.aug_kwargs),
                                train=True,
                                **args.dataset_kwargs)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        sampler=train_sampler,
        **args.dataloader_kwargs)

    memory_dataset = get_dataset(transform=get_aug(train=False,
                                                   train_classifier=False,
                                                   **args.aug_kwargs),
                                 train=True,
                                 **args.dataset_kwargs)

    memory_loader = torch.utils.data.DataLoader(
        dataset=memory_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        **args.dataloader_kwargs)

    test_dataset = get_dataset(transform=get_aug(train=False,
                                                 train_classifier=False,
                                                 **args.aug_kwargs),
                               train=False,
                               **args.dataset_kwargs)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        **args.dataloader_kwargs)
    print("Batch size:", (args.train.batch_size // args.gpus))
    # define model
    model = get_model(args.model).cuda(gpu)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    # define optimizer
    optimizer = get_optimizer(args.train.optimizer.name,
                              model,
                              lr=args.train.base_lr * args.train.batch_size /
                              256,
                              momentum=args.train.optimizer.momentum,
                              weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs,
        args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs,
        args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )
    if gpu == 0:
        logger = Logger(tensorboard=args.logger.tensorboard,
                        matplotlib=args.logger.matplotlib,
                        log_dir=args.log_dir)
    accuracy = 0
    # Start training
    global_progress = tqdm(range(0, args.train.stop_at_epoch),
                           desc=f'Training')
    for epoch in global_progress:
        model.train()

        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.train.num_epochs}',
                              disable=args.hide_progress)
        for idx, ((images1, images2), labels) in enumerate(local_progress):

            model.zero_grad()
            data_dict = model.forward(images1.cuda(non_blocking=True),
                                      images2.cuda(non_blocking=True))
            loss = data_dict['loss']  # ddp
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})
            local_progress.set_postfix(data_dict)
            if gpu == 0:
                logger.update_scalers(data_dict)

        if args.train.knn_monitor and epoch % args.train.knn_interval == 0 and gpu == 0:
            accuracy = knn_monitor(model.module.backbone,
                                   memory_loader,
                                   test_loader,
                                   gpu,
                                   k=min(args.train.knn_k,
                                         len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)

        if gpu == 0:
            logger.update_scalers(epoch_dict)

        # Save checkpoint
        if gpu == 0 and epoch % args.train.knn_interval == 0:
            model_path = os.path.join(
                args.ckpt_dir, f"{args.name}_{epoch+1}.pth"
            )  # datetime.now().strftime('%Y%m%d_%H%M%S')
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.module.state_dict()
                }, model_path)
            print(f"Model saved to {model_path}")
            with open(os.path.join(args.log_dir, f"checkpoint_path.txt"),
                      'w+') as f:
                f.write(f'{model_path}')

    # if args.eval is not False and gpu == 0:
    #     args.eval_from = model_path
    #     linear_eval(args)

    dist.destroy_process_group()
Example #20
    def __init__(self, basePort):
        self.basePort = basePort
        self.Logger = Logger('9999', 'WebInstanceManager')
        self.MaxInstances = 20
        self.InstanceList = {}
        self.Logger.addLog("Instance Manager: Started")
Example #21
def main(config):

    # init loaders and base
    loaders = Loaders(config)
    base = Base(config, loaders)

    # make directories
    make_dirs(base.output_path)
    make_dirs(base.save_model_path)
    make_dirs(base.save_logs_path)
    make_dirs(base.save_visualize_market_path)
    make_dirs(base.save_visualize_duke_path)

    # init logger
    logger = Logger(
        os.path.join(os.path.join(config.output_path, 'logs/'), 'log.txt'))
    logger('\n' * 3)
    logger(config)

    if config.mode == 'train':  # train mode

        # resume model from the resume_train_epoch
        start_train_epoch = 0

        # automatically resume model from the latest one
        if config.auto_resume_training_from_lastest_steps:
            root, _, files = os_walk(base.save_model_path)
            if len(files) > 0:
                # get indexes of saved models
                indexes = []
                for file in files:
                    indexes.append(int(
                        file.replace('.pkl', '').split('_')[-1]))
                indexes = sorted(list(set(indexes)), reverse=False)
                # resume model from the latest model
                base.resume_model(indexes[-1])
                #
                start_train_epoch = indexes[-1]
                logger(
                    'Time: {}, automatically resume training from the latest step (model {})'
                    .format(time_now(), indexes[-1]))

        # main loop
        for current_epoch in range(start_train_epoch,
                                   config.total_train_epochs):
            # save model
            base.save_model(current_epoch)
            # train
            base.lr_scheduler.step(current_epoch)
            _, results = train_an_epoch(config, base, loaders, current_epoch)
            logger('Time: {};  Epoch: {};  {}'.format(time_now(),
                                                      current_epoch, results))
        # test
        testwithVer2(config,
                     logger,
                     base,
                     loaders,
                     'duke',
                     use_gcn=True,
                     use_gm=True)

    elif config.mode == 'test':  # test mode
        # resume from the resume_test_epoch
        if config.resume_test_path != '' and config.resume_test_epoch != 0:
            base.resume_model_from_path(config.resume_test_path,
                                        config.resume_test_epoch)
        else:
            assert 0, 'please set resume_test_path and resume_test_epoch '
        # test
        duke_map, duke_rank = testwithVer2(config,
                                           logger,
                                           base,
                                           loaders,
                                           'duke',
                                           use_gcn=False,
                                           use_gm=False)
        logger('Time: {},  base, Dataset: Duke  \nmAP: {} \nRank: {}'.format(
            time_now(), duke_map, duke_rank))
        duke_map, duke_rank = testwithVer2(config,
                                           logger,
                                           base,
                                           loaders,
                                           'duke',
                                           use_gcn=True,
                                           use_gm=False)
        logger(
            'Time: {},  base+gcn, Dataset: Duke  \nmAP: {} \nRank: {}'.format(
                time_now(), duke_map, duke_rank))
        duke_map, duke_rank = testwithVer2(config,
                                           logger,
                                           base,
                                           loaders,
                                           'duke',
                                           use_gcn=True,
                                           use_gm=True)
        logger('Time: {},  base+gcn+gm, Dataset: Duke  \nmAP: {} \nRank: {}'.
               format(time_now(), duke_map, duke_rank))
        logger('')

    elif config.mode == 'visualize':  # visualization mode
        # resume from the resume_visualize_epoch
        if config.resume_visualize_path != '' and config.resume_visualize_epoch != 0:
            base.resume_model_from_path(config.resume_visualize_path,
                                        config.resume_visualize_epoch)
            print('Time: {}, resume model from {} {}'.format(
                time_now(), config.resume_visualize_path,
                config.resume_visualize_epoch))
        # visualization
        if 'market' in config.train_dataset:
            visualize_ranked_images(config, base, loaders, 'market')
        elif 'duke' in config.train_dataset:
            visualize_ranked_images(config, base, loaders, 'duke')
        else:
            assert 0
Example #22
def create_routes():
    bp = Blueprint(__name__, 'nurture')
    loggre = Logger()

    try:
        gpt_t, gpt_m, gpt_a = preLoadGpt2()
        ctrl_t, ctrl_m, ctrl_a = preLoadCtrl()
        topic_engine_flag = 1
        loggre.info(
            'Topic to paragraph generator engine started SUCCESSFULLY!!!')
    except Exception as e:
        loggre.error(
            'Topic to paragraph generator engine load ERROR!!!')
        raise(e)

    def exception_handler(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except:
                loggre.error(traceback.format_exc())
                return jsonify(code=1201999, msg='inner error'), 200

        wrapper.__name__ = func.__name__
        return wrapper

    def post(url, data):
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        return requests.post(url, data=json.dumps(data), headers=headers)

    @bp.route("/topic_gen", methods=('post', ))
    @exception_handler
    def topic_gen():
        """
        topic to paragraph文本生成
        """
        body = request.get_json()
        loggre.info(json.dumps(body))

        if 'title' not in body or (not body['title']):
            loggre.error('title is undefined or it is an invalid value')
            return jsonify(
                code=1202002,
                msg='title is undefined or it is an invalid value.'
            )
        if 'length' not in body or not body['length'] or (str.isdigit(str(body['length'])) != True
                                                          or 100 > int(body['length'])
                                                          or int(body['length']) > 500):
            loggre.error(
                'length is undefined or it is an invalid value (length must be an integer | length must be between 100 and 500)')
            return jsonify(
                code=1202002,
                msg='length is undefined or it is an invalid value (length must be an integer | length must be between 100 and 500)'
            ), 200
        if 'nums' not in body or not body['nums'] or (str.isdigit(str(body['nums'])) != True
                                                      or 1 > int(body['nums'])
                                                      or int(body['nums']) > 20):
            loggre.error(
                'nums is undefined or it is an invalid value (nums must be an integer | nums must be between 1 and 20)')
            return jsonify(
                code=1202002,
                msg='nums is undefined or it is an invalid value (nums must be an integer | nums must be between 1 and 20)'
            ), 200
        loggre.info(body)
        
        title=str(body['title'])
        length=int(body['length'])
        nums=int(body['nums'])

        if topic_engine_flag == 1:

            # give roughly two thirds of the requested samples to CTRL and the rest to GPT-2
            ctrl_sample_nums = int(nums / 1.5) + nums % 2
            gpt2_sample_nums = nums - ctrl_sample_nums

            text_ctrl = genText(ctrl_t, ctrl_m, ctrl_a, length,
                                ctrl_sample_nums, "News " + title)
            print(text_ctrl)
            if gpt2_sample_nums != 0:
                text_gpt2 = genText(gpt_t, gpt_m, gpt_a, length,
                                    gpt2_sample_nums, title + " | ")
            else:
                text_gpt2 = []
            text_list = split_combine_text(text_ctrl, text_gpt2)
            return jsonify(code=0, msg=0, data=text_list), 200
        else:
            loggre.error(
                'Topic to paragraph generator engine is not loaded !')
            return jsonify(
                code=1202002,
                msg='Topic to paragraph generator engine is not loaded !'
            ), 200

    return bp
Example #23
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)
    print('==> Preparing dataset %s' % args.dataset)

    if args.dataset == 'cifar100':
        training_loader = get_training_dataloader(settings.CIFAR100_TRAIN_MEAN,
                                                  settings.CIFAR100_TRAIN_STD,
                                                  num_workers=4,
                                                  batch_size=args.train_batch,
                                                  shuffle=True)

        test_loader = get_test_dataloader(settings.CIFAR100_TRAIN_MEAN,
                                          settings.CIFAR100_TRAIN_STD,
                                          num_workers=4,
                                          batch_size=args.test_batch,
                                          shuffle=False)
        num_classes = 100
    else:
        training_loader = get_training_dataloader_10(
            settings.CIFAR10_TRAIN_MEAN,
            settings.CIFAR10_TRAIN_STD,
            num_workers=4,
            batch_size=args.train_batch,
            shuffle=True)

        test_loader = get_test_dataloader_10(settings.CIFAR10_TRAIN_MEAN,
                                             settings.CIFAR10_TRAIN_STD,
                                             num_workers=4,
                                             batch_size=args.test_batch,
                                             shuffle=False)
        num_classes = 10
    #data preprocessing:
    print("==> creating model '{}'".format(args.arch))

    model = get_network(args, num_classes=num_classes)
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
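    # AM-Softmax margin loss used alongside standard cross-entropy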
    criterion1 = am_softmax.AMSoftmax()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    train_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.schedule, gamma=0.2)  #learning rate decay
    iter_per_epoch = len(training_loader)
    warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

    for epoch in range(start_epoch, args.epochs):
        if epoch > args.warm:
            train_scheduler.step(epoch)
        train_loss, train_acc = train(training_loader, model, warmup_scheduler,
                                      criterion, criterion1, optimizer, epoch,
                                      use_cuda)
        test_loss, test_acc = eval_training(test_loader, model, criterion,
                                            epoch, use_cuda)

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, test_loss, train_acc,
            test_acc
        ])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.checkpoint)

    logger.close()
    # logger.plot()
    # savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)
Example #24
File: turing.py Project: mocurin/MLaAT
from typing import NamedTuple, List, Tuple, Dict, Iterable
from enum import IntEnum
from builtins import str
from tools import Logger

# Logging utils
_turing_logger = Logger()
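# register per-event print callbacks (this Logger apparently supports dict-style item assignment)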
_turing_logger['start'] = lambda tape: print(f"Initial tape: '{tape}'")
_turing_logger['stop'] = lambda tape: print(f"Result tape: '{tape}'")
_turing_logger['process'] = lambda pos, cond, rep, tape: print(
    f"I:{pos}, St:{cond.state}; Rule: (q{cond.state},{cond.letter})->(q{rep.state},"
    f"{rep.letter if rep.letter else 'Empty'},{rep.move.name}); Tape: '{tape}';"
)


class Move(IntEnum):
    L = -1
    St = 0
    R = 1


# Named tuples are especially good here, as they provide sort of
# readable access to tuple members. Though, now I have to convert tuples
# to named tuples
Condition = NamedTuple('Condition', [('state', int), ('letter', str)])
Replacement = NamedTuple('Replacement', [('state', int), ('letter', str),
                                         ('move', Move)])
Rule = Tuple[Condition, Replacement]


# Dict is too much for such a simple structure. Also tried
Example #25
from tools import Logger
import config
from evaluation import evaluate as svm_classify

import os
os.environ['CUDA_VISIBLE_DEVICES'] = str(config.GPU_ID)

#load data
dataset_choice = config.dataset_choice
graph_path = './datasets/%s/graph.txt' % dataset_choice  # train_
text_path = './datasets/%s/data.txt' % dataset_choice
group_path = './datasets/%s/group.txt' % dataset_choice

model_name = 'DetGP'
log_save_path = './logs/'
log = Logger(log_save_path + model_name + '.txt')
data = DataLoader(text_path, graph_path, 1.)


def eval_node_classification(sess, model, model_name, ratio):
    text_emb, struct_emb = sess.run(
        [model.text_emb_a, model.struct_emb_a],
        feed_dict={
            # model.edges: data.edges,
            # model.text_all: data.text,
            model.node_a_ids:
            np.arange(data.num_nodes)
        })
    node_emb = np.concatenate((text_emb, struct_emb), axis=1)
    seen_nodes = data.train_graph.nodes()
Example #26
def main(config):

    # loaders and base
    loaders = Loaders(config)
    base = Base(config, loaders)

    # make dirs
    make_dirs(config.save_images_path)
    make_dirs(config.save_models_path)
    make_dirs(config.save_features_path)

    # logger
    logger = Logger(os.path.join(config.output_path, 'log.txt'))
    logger(config)

    if config.mode == 'train':

        # automatically resume model from the latest one
        start_train_epoch = 0
        root, _, files = os_walk(config.save_models_path)
        if len(files) > 0:
            # get indexes of saved models
            indexes = []
            for file in files:
                indexes.append(int(file.replace('.pkl', '').split('_')[-1]))

            # remove the bad-case and get available indexes
            model_num = len(base.model_list)
            available_indexes = copy.deepcopy(indexes)
            for element in indexes:
                if indexes.count(element) < model_num:
                    available_indexes.remove(element)

            available_indexes = sorted(list(set(available_indexes)),
                                       reverse=True)
            unavailable_indexes = list(
                set(indexes).difference(set(available_indexes)))

            if len(available_indexes
                   ) > 0:  # resume model from the latest model
                base.resume_model(available_indexes[0])
                start_train_epoch = available_indexes[0] + 1
                logger(
                    'Time: {}, automatically resume training from the latest step (model {})'
                    .format(time_now(), available_indexes[0]))
            else:
                logger('Time: {}, there are no available models'.format(
                    time_now()))

        # main loop
        for current_epoch in range(
                start_train_epoch, config.warmup_reid_epoches +
                config.warmup_gan_epoches + config.train_epoches):

            # test
            if current_epoch % 10 == 0 and current_epoch > config.warmup_reid_epoches + config.warmup_gan_epoches:
                results = test(config, base, loaders, brief=True)
                for key in results.keys():
                    logger('Time: {}\n Setting: {}\n {}'.format(
                        time_now(), key, results[key]))

            # visualize generated images
            if current_epoch % 10 == 0 or current_epoch <= 10:
                visualize(config, loaders, base, current_epoch)

            # train
            if current_epoch < config.warmup_reid_epoches:  # warmup reid model
                results = train_an_epoch(config,
                                         loaders,
                                         base,
                                         current_epoch,
                                         train_gan=True,
                                         train_reid=True,
                                         train_pixel=False,
                                         optimize_sl_enc=True)
            elif current_epoch < config.warmup_reid_epoches + config.warmup_gan_epoches:  # warmup GAN model
                results = train_an_epoch(config,
                                         loaders,
                                         base,
                                         current_epoch,
                                         train_gan=True,
                                         train_reid=False,
                                         train_pixel=False,
                                         optimize_sl_enc=False)
            else:  # joint train
                results = train_an_epoch(config,
                                         loaders,
                                         base,
                                         current_epoch,
                                         train_gan=True,
                                         train_reid=True,
                                         train_pixel=True,
                                         optimize_sl_enc=True)
            logger('Time: {};  Epoch: {};  {}'.format(time_now(),
                                                      current_epoch, results))

            # save model
            base.save_model(current_epoch)

        # test
        results = test(config, base, loaders, brief=False)
        for key in results.keys():
            logger('Time: {}\n Setting: {}\n {}'.format(
                time_now(), key, results[key]))

    elif config.mode == 'test':
        # resume from pre-trained model and test
        base.resume_model_from_path(config.pretrained_model_path,
                                    config.pretrained_model_epoch)
        results = test(config, base, loaders, brief=False)
        for key in results.keys():
            logger('Time: {}\n Setting: {}\n {}'.format(
                time_now(), key, results[key]))
Example #27
import sys
import os
from tools import Logger

log_dir = ''
log_dir += sys.argv[1]

if log_dir == '':
    print('[ERROR] Please provide the log directory')
    exit()
log = Logger(log_dir, '*log_of_logs.log', printLog=True, timestampOn=False)
log.emit('NEW STATUS REPORT', newRun=True)

filenames = os.listdir(log_dir)
filenames.sort()

sleepsightLogs = []
for filename in filenames:
    if 'sleepsight' in filename:
        ssLog = Logger(log_dir, filename)
        log.emit('{}\t{}'.format(filename, ssLog.getLastMessage()))

Example #28
path = '/Users/Kerz/Documents/projects/SleepSight/ANALYSIS/data/'
plot_path = '/Users/Kerz/Documents/projects/SleepSight/ANALYSIS/plots/'
log_path = '/Users/Kerz/Documents/projects/SleepSight/ANALYSIS/logs/'

options = {'periodicity': False,
           'participant-info': False,
           'compliance': False,
           'stationarity': False,
           'symptom-score-discretisation': False,
           'feature-delay': False,
           'feature-selection': False,
           'non-parametric-svm': False,
           'non-parametric-gp': True
           }

log = Logger(log_path, 'thesis_outputs.log', printLog=True)

# Load Participants
log.emit('Loading participants...', newRun=True)
aggr = T.Aggregates('.pkl', path, plot_path)


# Export Periodicity tables
if options['periodicity']:
    log.emit('Generating PERIODICITY table...')
    pt = T.PeriodictyTable(aggr, log)
    pt.run()
    pt.exportLatexTable(summary=False)
    pt.exportLatexTable(summary=True)

Example #29
from typing import Callable
import numpy as np
from sklearn.utils import shuffle
from sklearn import metrics
from random import seed
import time
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import torch.utils.data as utils
from tools import Logger, accuracy, TotalMeter

logger = Logger()

partition = np.loadtxt('example_adjacency.txt', dtype=int, delimiter=None)
expression = np.loadtxt('example_expression.csv', dtype=float, delimiter=",")
labels = np.array(expression[:, -1], dtype=int)
expression = np.array(expression[:, :-1])

# train/test data split
cut = int(0.8 * np.shape(expression)[0])
expression, labels = shuffle(expression, labels)
x_train = expression[:cut, :]
x_test = expression[cut:, :]
y_train = labels[:cut]
y_test = labels[cut:]
print(x_test.shape)
partition = torch.from_numpy(partition).float()
def main(device, args):
    train_directory = '../data/train'
    image_name_file = '../data/original.csv'
    val_directory = '../data/train'
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset('random', train_directory, image_name_file,
            transform=get_aug(train=True, **args.aug_kwargs),
            train=True,
            **args.dataset_kwargs),
        # dataset=datasets.ImageFolder(root=train_directory, transform=get_aug(train=True, **args.aug_kwargs)),
        shuffle=True,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )

    memory_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=val_directory, transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=val_directory, transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )

    # define model
    model = get_model(args.model).to(device)
    model = torch.nn.DataParallel(model)
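    # gradient scaler for mixed-precision (AMP) training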
    scaler = torch.cuda.amp.GradScaler()

    # define optimizer
    optimizer = get_optimizer(
        args.train.optimizer.name, model,
        lr=args.train.base_lr * args.train.batch_size / 256,
        momentum=args.train.optimizer.momentum,
        weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs, args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs, args.train.base_lr * args.train.batch_size / 256,
                                  args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )

    RESUME = False
    start_epoch = 0

    if RESUME:
        model = get_backbone(args.model.backbone)
        classifier = nn.Linear(in_features=model.output_dim, out_features=9, bias=True).to(args.device)

        assert args.eval_from is not None
        save_dict = torch.load(args.eval_from, map_location='cpu')
        msg = model.load_state_dict({k[9:]: v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')},
                                    strict=True)

        path_checkpoint = "./checkpoint/simsiam-TCGA-0218-nearby_0221134812.pth"  # checkpoint path
        checkpoint = torch.load(path_checkpoint)  # load the checkpoint

        model.load_state_dict(checkpoint['net'])  # load learnable model parameters

        optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer state
        start_epoch = checkpoint['epoch']  # set the starting epoch

    logger = Logger(tensorboard=args.logger.tensorboard, matplotlib=args.logger.matplotlib, log_dir=args.log_dir)
    accuracy = 0
    # Start training
    global_progress = tqdm(range(start_epoch, args.train.stop_at_epoch), desc=f'Training')
    for epoch in global_progress:
        model.train()

        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress)
        for idx, (images1, images2, images3, labels) in enumerate(local_progress):
            model.zero_grad()
            with torch.cuda.amp.autocast():
                data_dict = model.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True),
                                          images3.to(device, non_blocking=True))
                loss = data_dict['loss'].mean()  # ddp
            # loss.backward()
            scaler.scale(loss).backward()
            # optimizer.step()
            scaler.step(optimizer)
            scaler.update()

            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})

            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
            accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, device,
                                   k=min(args.train.knn_k, len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch
        }
        if (epoch % args.train.save_interval) == 0:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict()
            }, './checkpoint/exp_0223_triple_400_proj3/ckpt_best_%s.pth' % (str(epoch)))

    # Save checkpoint
    model_path = os.path.join(args.ckpt_dir,
                              f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth")  # datetime.now().strftime('%Y%m%d_%H%M%S')
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.module.state_dict()
    }, model_path)
    print(f"Model saved to {model_path}")
    with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f:
        f.write(f'{model_path}')


    if args.eval is not False:
        args.eval_from = model_path
        linear_eval(args)