def stacking(self,
                 train_loader,
                 val_loader,
                 epochs=10,
                 eval_metric='loss'):

        if not isinstance(self.model, Ensemble):
            raise TypeError("model must be an instance of Ensemble")

        optimizer = get_optimizer(self.stacking_model.parameters(), self.cfg)
        lambda1 = lambda epoch: 0.9**epoch
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lambda1)

        os.makedirs(os.path.join('experiment', self.cfg.log_dir),
                    exist_ok=True)
        ckp_dir = os.path.join('experiment', self.cfg.log_dir, 'checkpoint')
        os.makedirs(ckp_dir, exist_ok=True)

        self.model.freeze()
        self.stacking_model.unfreeze()
        self.stacking_model.to(self.device)

        running_loss = AverageMeter()
        best_metric = 0.0

        for epoch in range(epochs):
            self.stacking_model.train()
            for i, data in enumerate(tqdm.tqdm(train_loader)):
                imgs, labels = data[0].to(self.device), data[1].to(self.device)
                preds = self.stacking_model(self.model(imgs))
                loss = self.metrics['loss'](preds, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss.update(loss.item(), imgs.shape[0])
            s = "Epoch [{}/{}]:\n".format(epoch + 1, epochs)
            s += "{}_{} {:.3f}\n".format('train', 'loss', running_loss.avg)
            self.stacking_model.eval()
            running_metrics = self.test(val_loader)
            # 'loss' is removed before the lookup below, so eval_metric must
            # name one of the remaining validation metrics rather than 'loss'.
            running_metrics.pop('loss')
            s = get_str(running_metrics, 'val', s)
            metric_eval = running_metrics[eval_metric]
            s = s[:-1] + "- mean_"+eval_metric + \
                " {:.3f}".format(metric_eval.mean())
            torch.save(self.stacking_model.state_dict(),
                       os.path.join(ckp_dir, 'latest.ckpt'))
            running_loss.reset()
            scheduler.step()
            print(s)
            if metric_eval.mean() > best_metric:
                best_metric = metric_eval.mean()
                shutil.copyfile(os.path.join(ckp_dir, 'latest.ckpt'),
                                os.path.join(ckp_dir, 'best.ckpt'))
                print('new checkpoint saved!')
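The stacking loop above depends on an AverageMeter helper that is not shown in this snippet. A minimal sketch of the running-average tracker its update/avg/reset calls imply (the exact implementation in the source may differ):

class AverageMeter:
    """Tracks a running average of a scalar such as the training loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        # value is the batch-mean loss, n the batch size.
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count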
Example 2
    def __init__(self, cfg, loss_func, metrics=None):
        """CheXpert class contains all functions used for training and testing our models

        Args:
            cfg (dict): configuration file.
            loss_func (torch.nn.Module): loss function of the model.
            metrics (dict, optional): metrics used to evaluate model performance. Defaults to None.
        """
        self.cfg = cfg
        if self.cfg.full_classes:
            self.cfg.num_classes = 14 * [1]
        self.model, self.childs_cut = get_model(self.cfg)
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.loss_func = loss_func
        if metrics is not None:
            self.metrics = metrics
            self.metrics['loss'] = self.loss_func
        else:
            self.metrics = {'loss': self.loss_func}
        self.optimizer = get_optimizer(self.model.parameters(), self.cfg)
        self.model.to(self.device)
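A hedged usage sketch for the constructor above. The class name CheXpert is taken from the docstring, and the cfg fields and loss function are illustrative assumptions only:

import torch.nn as nn
from easydict import EasyDict as edict

# Hypothetical configuration; real configs carry many more fields
# (model architecture, optimizer settings, and so on).
cfg = edict({'full_classes': True, 'num_classes': 14 * [1],
             'optimizer': 'Adam', 'lr': 1e-4})
chexpert = CheXpert(cfg, loss_func=nn.BCEWithLogitsLoss())
# With metrics=None, chexpert.metrics contains only the 'loss' entry.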
Example 3
def run(args):
    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose is True:
            print(json.dumps(cfg, indent=4))

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    if args.logtofile is True:
        logging.basicConfig(filename=args.save_path + '/log.txt',
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    device = torch.device('cuda:{}'.format(device_ids[0]))

    model = Classifier(cfg)
    if args.verbose is True:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))
    model = DataParallel(model, device_ids=device_ids).to(device).train()
    if args.pre_train is not None:
        if os.path.exists(args.pre_train):
            ckpt = torch.load(args.pre_train, map_location=device)
            model.module.load_state_dict(ckpt)
    optimizer = get_optimizer(model.parameters(), cfg)

    src_folder = os.path.dirname(os.path.abspath(__file__)) + '/../'
    dst_folder = os.path.join(args.save_path, 'classification')
    # rc, size = subprocess.getstatusoutput('dir --max-depth=0 %s | cut -f1'
    #                                       % src_folder)
    # if rc != 0:
    #     print(size)
    #     raise Exception('Copy folder error : {}'.format(rc))
    # rc, err_msg = subprocess.getstatusoutput('cp -R %s %s' % (src_folder,
    #                                                           dst_folder))
    # if rc != 0:
    #     raise Exception('copy folder error : {}'.format(err_msg))

    copyfile(cfg.train_csv, os.path.join(args.save_path, 'train.csv'))
    copyfile(cfg.dev_csv, os.path.join(args.save_path, 'valid.csv'))

    dataloader_train = DataLoader(ImageDataset(cfg.train_csv,
                                               cfg,
                                               mode='train'),
                                  batch_size=cfg.train_batch_size,
                                  num_workers=args.num_workers,
                                  drop_last=True,
                                  shuffle=True)
    dataloader_dev = DataLoader(ImageDataset(cfg.dev_csv, cfg, mode='dev'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    dev_header = dataloader_dev.dataset._label_header

    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    epoch_start = 0
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    if args.resume:
        ckpt_path = os.path.join(args.save_path, 'train.ckpt')
        ckpt = torch.load(ckpt_path, map_location=device)
        model.module.load_state_dict(ckpt['state_dict'])
        summary_train = {'epoch': ckpt['epoch'], 'step': ckpt['step']}
        best_dict['acc_dev_best'] = ckpt['acc_dev_best']
        best_dict['loss_dev_best'] = ckpt['loss_dev_best']
        best_dict['auc_dev_best'] = ckpt['auc_dev_best']
        epoch_start = ckpt['epoch']

    for epoch in range(epoch_start, cfg.epoch):
        lr = lr_schedule(cfg.lr, cfg.lr_factor, summary_train['epoch'],
                         cfg.lr_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        summary_train, best_dict = train_epoch(summary_train, summary_dev, cfg,
                                               args, model, dataloader_train,
                                               dataloader_dev, optimizer,
                                               summary_writer, best_dict,
                                               dev_header)

        time_now = time.time()
        summary_dev, predlist, true_list = test_epoch(summary_dev, cfg, args,
                                                      model, dataloader_dev)
        time_spent = time.time() - time_now

        auclist = []
        for i in range(len(cfg.num_classes)):
            y_pred = predlist[i]
            y_true = true_list[i]
            fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                     y_pred,
                                                     pos_label=1)
            auc = metrics.auc(fpr, tpr)
            auclist.append(auc)
        summary_dev['auc'] = np.array(auclist)

        loss_dev_str = ' '.join(
            map(lambda x: '{:.5f}'.format(x), summary_dev['loss']))
        acc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['acc']))
        auc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['auc']))

        logging.info('{}, Dev, Step : {}, Loss : {}, Acc : {}, Auc : {}, '
                     'Mean auc: {:.3f} '
                     'Run Time : {:.2f} sec'.format(
                         time.strftime("%Y-%m-%d %H:%M:%S"),
                         summary_train['step'], loss_dev_str, acc_dev_str,
                         auc_dev_str, summary_dev['auc'].mean(), time_spent))

        for t in range(len(cfg.num_classes)):
            summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                      summary_dev['loss'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                      summary_dev['acc'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                      summary_dev['auc'][t],
                                      summary_train['step'])

        save_best = False

        mean_acc = summary_dev['acc'][cfg.save_index].mean()
        if mean_acc >= best_dict['acc_dev_best']:
            best_dict['acc_dev_best'] = mean_acc
            if cfg.best_target == 'acc':
                save_best = True

        mean_auc = summary_dev['auc'][cfg.save_index].mean()
        if mean_auc >= best_dict['auc_dev_best']:
            best_dict['auc_dev_best'] = mean_auc
            if cfg.best_target == 'auc':
                save_best = True

        mean_loss = summary_dev['loss'][cfg.save_index].mean()
        if mean_loss <= best_dict['loss_dev_best']:
            best_dict['loss_dev_best'] = mean_loss
            if cfg.best_target == 'loss':
                save_best = True

        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'acc_dev_best': best_dict['acc_dev_best'],
                    'auc_dev_best': best_dict['auc_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': model.module.state_dict()
                },
                os.path.join(args.save_path,
                             'best{}.ckpt'.format(best_dict['best_idx'])))
            best_dict['best_idx'] += 1
            if best_dict['best_idx'] > cfg.save_top_k:
                best_dict['best_idx'] = 1
            logging.info('{}, Best, Step : {}, Loss : {}, Acc : {}, '
                         'Auc : {}, Best Auc : {:.3f}'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str, acc_dev_str,
                             auc_dev_str, best_dict['auc_dev_best']))
        torch.save(
            {
                'epoch': summary_train['epoch'],
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': model.module.state_dict()
            }, os.path.join(args.save_path, 'train.ckpt'))
    summary_writer.close()
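lr_schedule is not defined in this snippet; from the way it is called with (cfg.lr, cfg.lr_factor, epoch, cfg.lr_epochs) it looks like a step-decay schedule. A minimal sketch under that assumption (the original implementation may differ):

def lr_schedule(lr, lr_factor, epoch, lr_epochs):
    # Multiply the base learning rate by lr_factor once for every
    # milestone in lr_epochs that the current epoch has passed.
    for milestone in lr_epochs:
        if epoch >= milestone:
            lr *= lr_factor
    return lr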
Example 4
def train(hyperparams: Hyperparameter):
    # -- hyperparams -- #
    dataset_params = hyperparams.subparams("dataset")
    config_params = hyperparams.subparams("configuration")
    train_params = hyperparams.subparams("train")
    model_params = hyperparams.subparams("model")
    output_params = hyperparams.subparams("output")

    os.makedirs(output_params.root_dir, exist_ok=True)
    if hasattr(output_params, "logname"):
        log.basicConfig(filename=os.path.join(output_params.root_dir,
                                              output_params.logname),
                        filemode="w",
                        level=get_log_level(output_params.log_level))
    else:
        log.basicConfig(level=get_log_level(output_params.log_level))
    hyperparams.save(os.path.join(output_params.root_dir, "hyperparams.json"))
    atomic_num_list = get_atomic_num_id(
        os.path.join(config_params.root_dir,
                     config_params.atom_id_to_atomic_num))

    data_parallel = False
    if isinstance(train_params.device, int):
        main_device = train_params.device
        device = main_device
    elif isinstance(train_params.device, dict):
        main_device = train_params.device["main"]
        device = train_params.device
        data_parallel = True
    else:
        raise ValueError("Invalid device.")
    log.info("Main Device: {}".format(main_device))

    log.info("dataset hyperparameters:\n{}\n".format(dataset_params))
    log.info("configuration hyperparameters:\n{}\n".format(config_params))
    log.info("train hyperparameters:\n{}\n".format(train_params))
    log.info("model hyperparameters:\n{}\n".format(model_params))
    log.info("output hyperparameters:\n{}\n".format(output_params))

    # -- build dataset -- #
    if config_params.has("train_validation_split"):
        validation_idxs = get_validation_idxs(
            os.path.join(config_params.root_dir,
                         config_params.train_validation_split))
    else:
        validation_idxs = None

    dataset = NumpyTupleDataset.load(
        os.path.join(dataset_params.root_dir, dataset_params.name))
    if validation_idxs:
        train_idxs = [
            i for i in range(len(dataset)) if i not in validation_idxs
        ]
        trainset_size = len(train_idxs)
        train_idxs.extend(validation_idxs)
        trainset, valset = chainer.datasets.split_dataset(
            dataset, trainset_size, train_idxs)
    else:
        trainset, valset = chainer.datasets.split_dataset_random(
            dataset, int(len(dataset) * 0.8), seed=777)

    train_iter = chainer.iterators.SerialIterator(trainset,
                                                  train_params.batch_size,
                                                  shuffle=True)
    val_iter = chainer.iterators.SerialIterator(valset,
                                                train_params.batch_size,
                                                repeat=False,
                                                shuffle=False)

    # -- model -- #
    model = AttentionNvpModel(model_params)
    if isinstance(device, dict):
        log.info("Using multi-GPU {}".format(device))
        model.to_gpu(main_device)
    elif device >= 0:
        log.info("Using GPU {}".format(device))
        chainer.cuda.get_device(main_device).use()
        model.to_gpu(device)
    else:
        log.info("Using CPU")

    # -- training details -- #
    num_epoch = train_params.num_epoch
    opt_gen = get_optimizer(train_params.optimizer)
    if train_params.has("optimizer_params"):
        optimizer = opt_gen(**train_params.optimizer_params)
    else:
        optimizer = opt_gen()

    optimizer.setup(model)
    if data_parallel:
        updater = DataParallelNVPUpdater(
            train_iter,
            optimizer,
            devices=device,
            two_step=train_params.two_step,
            h_nll_weight=train_params.h_nll_weight)
    else:
        updater = NVPUpdater(train_iter,
                             optimizer,
                             device=device,
                             two_step=train_params.two_step,
                             h_nll_weight=train_params.h_nll_weight)
    trainer = training.Trainer(updater, (num_epoch, "epoch"),
                               out=output_params.root_dir)
    if train_params.has("save_epoch"):
        save_epoch = train_params.save_epoch
    else:
        save_epoch = num_epoch

    # -- evaluation function -- #
    def print_validity(trainer):
        with chainer.using_device(
                chainer.backends.cuda.get_device_from_id(
                    main_device)), chainer.using_config("train", False):
            save_mol = (get_log_level(output_params.log_level) <= log.DEBUG)
            x, adj = generate_mols(model, batch_size=100,
                                   device=main_device)  # x: atom id
            valid_mols = check_validity(x,
                                        adj,
                                        atomic_num_list=atomic_num_list,
                                        device=main_device)
            if save_mol:
                mol_dir = os.path.join(
                    output_params.root_dir, output_params.saved_mol_dir,
                    "generated_{}".format(trainer.updater.epoch))
                os.makedirs(mol_dir, exist_ok=True)
                for i, mol in enumerate(valid_mols["valid_mols"]):
                    save_mol_png(mol, os.path.join(mol_dir,
                                                   "{}.png".format(i)))

    # -- trainer extension -- #
    trainer.extend(extensions.snapshot(), trigger=(save_epoch, "epoch"))
    trainer.extend(extensions.LogReport(filename=output_params.trainlogname))
    trainer.extend(print_validity, trigger=(1, "epoch"))
    trainer.extend(
        extensions.PrintReport([
            "epoch", "neg_log_likelihood", "nll_x", "nll_adj", "z_var",
            "ln_det_x", "ln_det_adj", "elapsed_time"
        ]))
    trainer.extend(extensions.ProgressBar())

    # -- start train -- #
    if hasattr(train_params, "load_snapshot"):
        log.info("Load snapshot from {}".format(train_params.load_snapshot))
        chainer.serializers.load_npz(train_params.load_snapshot, trainer)
    trainer.run()
    chainer.serializers.save_npz(
        os.path.join(output_params.root_dir, output_params.final_model_name),
        model)
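get_optimizer here takes only an optimizer name from the hyperparameters and returns a class to instantiate, unlike the PyTorch get_optimizer(params, cfg) used in the other examples. A minimal sketch under that assumption:

import chainer

def get_optimizer(name):
    # Return the optimizer class; the caller instantiates it, optionally
    # with optimizer_params (see the usage above).
    optimizers = {
        "adam": chainer.optimizers.Adam,
        "sgd": chainer.optimizers.SGD,
        "momentum_sgd": chainer.optimizers.MomentumSGD,
    }
    return optimizers[name.lower()]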
Example 5
def run(args, val_h5_file):
    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose is True:
            print(json.dumps(cfg, indent=4))

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    if args.logtofile is True:
        logging.basicConfig(filename=args.save_path + '/log.txt',
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    device = torch.device('cuda:{}'.format(device_ids[0]))

    model = Classifier(cfg)
    if args.verbose is True:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))
    model = DataParallel(model, device_ids=device_ids).to(device).train()
    if args.pre_train is not None:
        if os.path.exists(args.pre_train):
            ckpt = torch.load(args.pre_train, map_location=device)
            model.module.load_state_dict(ckpt)
    optimizer = get_optimizer(model.parameters(), cfg)

    #src_folder = os.path.dirname(os.path.abspath(__file__)) + '/../'
    #dst_folder = os.path.join(args.save_path, 'classification')
    #rc, size = subprocess.getstatusoutput('du --max-depth=0 %s | cut -f1' % src_folder)
    #if rc != 0: raise Exception('Copy folder error : {}'.format(rc))
    #rc, err_msg = subprocess.getstatusoutput('cp -R %s %s' % (src_folder, dst_folder))
    #if rc != 0: raise Exception('copy folder error : {}'.format(err_msg))
    #copyfile(cfg.train_csv, os.path.join(args.save_path, 'train.csv'))
    #copyfile(cfg.dev_csv, os.path.join(args.save_path, 'dev.csv'))
    # np_train_h5_file = np.array(train_h5_file['train'][:10000], dtype=np.uint8)
    # np_t_u_ones = np.array(train_h5_file['train_u_ones'][:10000], dtype=np.int8)
    # np_t_u_zeros = np.array(train_h5_file['train_u_zeros'][:10000], dtype=np.int8)
    # np_t_u_random = np.array(train_h5_file['train_u_random'][:10000], dtype=np.int8)

    np_val_h5_file = np.array(val_h5_file['val'], dtype=np.uint8)
    np_v_u_ones = np.array(val_h5_file['val_u_ones'], dtype=np.int8)
    np_v_u_zeros = np.array(val_h5_file['val_u_zeros'], dtype=np.int8)
    np_v_u_random = np.array(val_h5_file['val_u_random'], dtype=np.int8)

    train_labels = {}
    with h5py.File(f'{args.train_chunks}/train_labels.h5', 'r') as fp:
        train_labels['train_u_ones'] = np.array(fp['train_u_ones'],
                                                dtype=np.int8)
        train_labels['train_u_zeros'] = np.array(fp['train_u_zeros'],
                                                 dtype=np.int8)
        train_labels['train_u_random'] = np.array(fp['train_u_random'],
                                                  dtype=np.int8)
    np_train_samples = None
    for i in range(args.chunk_count):
        with open(f'{args.train_chunks}/chexpert_dset_chunk_{i+1}.npy',
                  'rb') as f:
            if np_train_samples is None:
                np_train_samples = np.load(f)
            else:
                np_train_samples = np.concatenate(
                    (np_train_samples, np.load(f)))

    dataloader_train = DataLoader(ImageDataset(
        [np_train_samples, train_labels], cfg, mode='train'),
                                  batch_size=cfg.train_batch_size,
                                  num_workers=args.num_workers,
                                  drop_last=True,
                                  shuffle=True)

    dataloader_dev = DataLoader(ImageDataset(
        [np_val_h5_file, np_v_u_zeros, np_v_u_ones, np_v_u_random],
        cfg,
        mode='val'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    #dev_header = dataloader_dev.dataset._label_header
    dev_header = [
        'No_Finding', 'Enlarged_Cardiomediastinum', 'Cardiomegaly',
        'Lung_Opacity', 'Lung_Lesion', 'Edema', 'Consolidation', 'Pneumonia',
        'Atelectasis', 'Pneumothorax', 'Pleural_Effusion', 'Pleural_Other',
        'Fracture', 'Support_Devices'
    ]
    print(f'dataloaders are set. train count: {np_train_samples.shape[0]}')
    logging.info("[LOGGING TEST]: dataloaders are set...")
    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    epoch_start = 0
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    if args.resume:
        ckpt_path = os.path.join(args.save_path, 'train.ckpt')
        ckpt = torch.load(ckpt_path, map_location=device)
        model.module.load_state_dict(ckpt['state_dict'])
        summary_train = {'epoch': ckpt['epoch'], 'step': ckpt['step']}
        best_dict['acc_dev_best'] = ckpt['acc_dev_best']
        best_dict['loss_dev_best'] = ckpt['loss_dev_best']
        best_dict['auc_dev_best'] = ckpt['auc_dev_best']
        epoch_start = ckpt['epoch']

    # One q/k value per class, as float tensors.
    q_list = torch.FloatTensor([args.q] * len(cfg.num_classes))
    k_list = torch.FloatTensor([args.k] * len(cfg.num_classes))
    loss_sq_hinge = MultiClassSquaredHingeLoss()
    print('Everything is set starting to train...')
    before = datetime.datetime.now()
    for epoch in range(epoch_start, cfg.epoch):
        lr = lr_schedule(cfg.lr, cfg.lr_factor, summary_train['epoch'],
                         cfg.lr_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        summary_train, best_dict = train_epoch(summary_train, summary_dev, cfg,
                                               args, model, dataloader_train,
                                               dataloader_dev, optimizer,
                                               summary_writer, best_dict,
                                               dev_header, q_list, k_list,
                                               loss_sq_hinge)

        time_now = time.time()
        summary_dev, predlist, true_list = test_epoch(summary_dev, cfg, args,
                                                      model, dataloader_dev,
                                                      q_list, k_list,
                                                      loss_sq_hinge)
        time_spent = time.time() - time_now

        auclist = []
        for i in range(len(cfg.num_classes)):
            y_pred = predlist[i]
            y_true = true_list[i]
            fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                     y_pred,
                                                     pos_label=1)
            auc = metrics.auc(fpr, tpr)
            auclist.append(auc)
        summary_dev['auc'] = np.array(auclist)

        loss_dev_str = ' '.join(
            map(lambda x: '{:.5f}'.format(x), summary_dev['loss']))
        acc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['acc']))
        auc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['auc']))

        logging.info('{}, Dev, Step : {}, Loss : {}, Acc : {}, Auc : {}, '
                     'Mean auc: {:.3f} '
                     'Run Time : {:.2f} sec'.format(
                         time.strftime("%Y-%m-%d %H:%M:%S"),
                         summary_train['step'], loss_dev_str, acc_dev_str,
                         auc_dev_str, summary_dev['auc'].mean(), time_spent))

        for t in range(len(cfg.num_classes)):
            summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                      summary_dev['loss'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                      summary_dev['acc'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                      summary_dev['auc'][t],
                                      summary_train['step'])

        save_best = False

        mean_acc = summary_dev['acc'][cfg.save_index].mean()
        if mean_acc >= best_dict['acc_dev_best']:
            best_dict['acc_dev_best'] = mean_acc
            if cfg.best_target == 'acc':
                save_best = True

        mean_auc = summary_dev['auc'][cfg.save_index].mean()
        if mean_auc >= best_dict['auc_dev_best']:
            best_dict['auc_dev_best'] = mean_auc
            if cfg.best_target == 'auc':
                save_best = True

        mean_loss = summary_dev['loss'][cfg.save_index].mean()
        if mean_loss <= best_dict['loss_dev_best']:
            best_dict['loss_dev_best'] = mean_loss
            if cfg.best_target == 'loss':
                save_best = True

        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'acc_dev_best': best_dict['acc_dev_best'],
                    'auc_dev_best': best_dict['auc_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': model.module.state_dict()
                },
                os.path.join(args.save_path,
                             'best{}.ckpt'.format(best_dict['best_idx'])))
            best_dict['best_idx'] += 1
            if best_dict['best_idx'] > cfg.save_top_k:
                best_dict['best_idx'] = 1
            logging.info('{}, Best, Step : {}, Loss : {}, Acc : {}, '
                         'Auc : {}, Best Auc : {:.3f}'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str, acc_dev_str,
                             auc_dev_str, best_dict['auc_dev_best']))
        torch.save(
            {
                'epoch': summary_train['epoch'],
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': model.module.state_dict()
            }, os.path.join(args.save_path, 'train.ckpt'))

        print_remaining_time(before,
                             epoch + 1,
                             cfg.epoch,
                             additional='[training]')
    summary_writer.close()
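print_remaining_time is not shown; its call signature suggests a simple ETA estimate based on the average epoch duration. A sketch under that assumption:

import datetime

def print_remaining_time(before, done_epochs, total_epochs, additional=''):
    # Scale the elapsed wall-clock time by the number of epochs still to run.
    elapsed = datetime.datetime.now() - before
    remaining = elapsed / done_epochs * (total_epochs - done_epochs)
    print('{} {}/{} epochs done, estimated time remaining: {}'.format(
        additional, done_epochs, total_epochs, remaining))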
Example 6
def main_worker(gpu, ngpus_per_node, args, config):
    set_seed(**config["seed"])
    logger = get_loguru_logger(args.log_dir, resume=args.resume, is_rank0=(gpu == 0))
    start_time = time.asctime(time.localtime(time.time()))
    logger.info("Start at: {} at: {}".format(start_time, platform.node()))
    torch.cuda.set_device(gpu)
    if args.distributed:
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend="nccl",
            init_method="tcp://127.0.0.1:{}".format(args.dist_port),
            world_size=args.world_size,
            rank=args.rank,
        )
        logger.warning("Only log rank 0 in distributed training!")

    logger.info("===Prepare data===")
    if "torch_transforms" in config:
        train_transform = TorchTransforms(config["torch_transforms"]["train"])
        test_transform = TorchTransforms(config["torch_transforms"]["test"])
    else:
        train_transform, test_transform = None, None
    logger.info("Torch training transformations:\n{}".format(train_transform))
    logger.info("Torch test transformations:\n{}".format(test_transform))
    logger.info("Load dataset from: {}".format(config["dataset_dir"]))
    train_data = get_dataset(config["dataset_dir"], train_transform)
    test_data = get_dataset(config["dataset_dir"], test_transform, train=False)
    prefetch = "prefetch" in config and config["prefetch"]
    logger.info("Prefetch: {}".format(prefetch))
    if args.distributed:
        train_sampler = DistributedSampler(train_data)
        # Divide batch size equally among multiple GPUs,
        # to keep the same learning rate used in a single GPU.
        batch_size = int(config["loader"]["batch_size"] / ngpus_per_node)
        num_workers = config["loader"]["num_workers"]
        train_loader = get_loader(
            train_data,
            prefetch=prefetch,
            batch_size=batch_size,
            sampler=train_sampler,
            num_workers=num_workers,
        )
    else:
        train_sampler = None
        train_loader = get_loader(
            train_data, prefetch=prefetch, loader_config=config["loader"], shuffle=True
        )
    test_loader = get_loader(
        test_data, prefetch=prefetch, loader_config=config["loader"]
    )

    logger.info("\n===Setup training===")
    model = get_network(config["network"])
    logger.info("Create network: {}".format(config["network"]))
    model = model.cuda(gpu)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda(gpu)
    logger.info("Create criterion: {}".format(criterion))
    optimizer = get_optimizer(model, config["optimizer"])
    logger.info("Create optimizer: {}".format(optimizer))
    scheduler = get_scheduler(optimizer, config["lr_scheduler"])
    logger.info("Create scheduler: {}".format(config["lr_scheduler"]))
    resumed_epoch, best_acc, best_epoch = resume_state(
        model,
        args.resume,
        args.ckpt_dir,
        logger,
        optimizer=optimizer,
        scheduler=scheduler,
        is_best=True,
    )
    if args.distributed:
        # Convert BatchNorm*D layer to SyncBatchNorm before wrapping Network with DDP.
        if "sync_bn" in config and config["sync_bn"]:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            logger.info("Turn on synchronized batch normalization in ddp.")
        model = DistributedDataParallel(model, device_ids=[gpu])

    for epoch in range(config["num_epochs"] - resumed_epoch):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        logger.info(
            "===Epoch: {}/{}===".format(epoch + resumed_epoch + 1, config["num_epochs"])
        )
        logger.info("Training...")
        train_result = train(
            model,
            train_loader,
            criterion,
            optimizer,
            logger,
            amp=args.amp,
        )
        logger.info("Test...")
        test_result = test(model, test_loader, criterion, logger)

        if scheduler is not None:
            scheduler.step()
            logger.info(
                "Adjust learning rate to {}".format(optimizer.param_groups[0]["lr"])
            )

        # Save result and checkpoint.
        if not args.distributed or (args.distributed and gpu == 0):
            result = {"train": train_result, "test": test_result}
            result2csv(result, args.log_dir)

            saved_dict = {
                "epoch": epoch + resumed_epoch + 1,
                "result": result,
                "optimizer_state_dict": optimizer.state_dict(),
                "best_acc": best_acc,
                "best_epoch": best_epoch,
            }
            if not "parallel" in str(type(model)):
                saved_dict["model_state_dict"] = model.state_dict()
            else:
                # DP or DDP.
                saved_dict["model_state_dict"] = model.module.state_dict()
            if scheduler is not None:
                saved_dict["scheduler_state_dict"] = scheduler.state_dict()

            is_best = False
            if test_result["acc"] > best_acc:
                is_best = True
                best_acc = test_result["acc"]
                best_epoch = epoch + resumed_epoch + 1
            logger.info(
                "Best test accuaracy {} in epoch {}".format(best_acc, best_epoch)
            )
            if is_best:
                ckpt_path = os.path.join(args.ckpt_dir, "best_model.pt")
                torch.save(saved_dict, ckpt_path)
                logger.info("Save the best model to {}".format(ckpt_path))
            ckpt_path = os.path.join(args.ckpt_dir, "latest_model.pt")
            torch.save(saved_dict, ckpt_path)
            logger.info("Save the latest model to {}".format(ckpt_path))

    end_time = time.asctime(time.localtime(time.time()))
    logger.info("End at: {} at: {}".format(end_time, platform.node()))
Example 7
def run_fl(args):
    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose is True:
            print(json.dumps(cfg, indent=4))

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    if args.logtofile is True:
        logging.basicConfig(filename=args.save_path + '/log.txt',
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    device = torch.device('cuda:{}'.format(device_ids[0]))

    # initialise global model
    model = Classifier(cfg).to(device).train()

    if args.verbose is True:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))

    if args.pre_train is not None:
        if os.path.exists(args.pre_train):
            ckpt = torch.load(args.pre_train, map_location=device)
            model.load_state_dict(ckpt)

    src_folder = os.path.dirname(os.path.abspath(__file__)) + '/../'
    dst_folder = os.path.join(args.save_path, 'classification')
    rc, size = subprocess.getstatusoutput('du --max-depth=0 %s | cut -f1' %
                                          src_folder)

    if rc != 0:
        raise Exception('Copy folder error : {}'.format(rc))
    else:
        print('Successfully determined size of directory')

    rc, err_msg = subprocess.getstatusoutput('cp -R %s %s' %
                                             (src_folder, dst_folder))
    if rc != 0:
        raise Exception('copy folder error : {}'.format(err_msg))
    else:
        print('Successfully copied folder')

    # copy train files
    train_files = cfg.train_csv
    clients = {}
    for i, c in enumerate(string.ascii_uppercase):
        if i < len(train_files):
            clients[c] = {}
        else:
            break

    # initialise clients
    for i, client in enumerate(clients):
        copyfile(train_files[i],
                 os.path.join(args.save_path, f'train_{client}.csv'))
        clients[client]['dataloader_train'] =\
            DataLoader(
                ImageDataset(train_files[i], cfg, mode='train'),
                batch_size=cfg.train_batch_size,
                num_workers=args.num_workers,drop_last=True,
                shuffle=True
            )
        clients[client]['bytes_uploaded'] = 0.0
        clients[client]['epoch'] = 0
    copyfile(cfg.dev_csv, os.path.join(args.save_path, 'dev.csv'))

    dataloader_dev = DataLoader(ImageDataset(cfg.dev_csv, cfg, mode='dev'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    dev_header = dataloader_dev.dataset._label_header

    w_global = model.state_dict()

    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    comm_rounds = cfg.epoch
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    # Communication rounds loop
    for cr in range(comm_rounds):
        logging.info('{}, Start communication round {} of FL - {} ...'.format(
            time.strftime("%Y-%m-%d %H:%M:%S"), cr + 1, cfg.fl_technique))

        w_locals = []

        for client in clients:

            logging.info(
                '{}, Start local training process for client {}, communication round: {} ...'
                .format(time.strftime("%Y-%m-%d %H:%M:%S"), client, cr + 1))

            # Load previous current global model as start point
            model = Classifier(cfg).to(device).train()

            model.load_state_dict(w_global)

            if cfg.fl_technique == "FedProx":
                global_weight_collector = get_global_weights(model, device)
            else:
                global_weight_collector = None

            optimizer = get_optimizer(model.parameters(), cfg)

            # local training loops
            for epoch in range(cfg.local_epoch):
                lr = lr_schedule(cfg.lr, cfg.lr_factor, epoch, cfg.lr_epochs)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                summary_train, best_dict = train_epoch_fl(
                    summary_train, summary_dev, cfg, args, model,
                    clients[client]['dataloader_train'], dataloader_dev,
                    optimizer, summary_writer, best_dict, dev_header, epoch,
                    global_weight_collector)

                summary_train['step'] += 1

            bytes_to_upload = sys.getsizeof(model.state_dict())
            clients[client]['bytes_uploaded'] += bytes_to_upload
            logging.info(
                '{}, Completed local rounds for client {} in communication round {}. '
                'Uploading {} bytes to server, {} bytes in total sent from client'
                .format(time.strftime("%Y-%m-%d %H:%M:%S"), client, cr + 1,
                        bytes_to_upload, clients[client]['bytes_uploaded']))

            w_locals.append(model.state_dict())

        if cfg.fl_technique == "FedAvg":
            w_global = fed_avg(w_locals)
        elif cfg.fl_technique == 'WFedAvg':
            w_global = weighted_fed_avg(w_locals, cfg.train_proportions)
        elif cfg.fl_technique == 'FedProx':
            # Use weighted FedAvg when using FedProx
            w_global = weighted_fed_avg(w_locals, cfg.train_proportions)

        # Test the performance of the averaged model
        avged_model = Classifier(cfg).to(device)
        avged_model.load_state_dict(w_global)

        time_now = time.time()
        summary_dev, predlist, true_list = test_epoch(summary_dev, cfg, args,
                                                      avged_model,
                                                      dataloader_dev)
        time_spent = time.time() - time_now

        auclist = []
        for i in range(len(cfg.num_classes)):
            y_pred = predlist[i]
            y_true = true_list[i]
            fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                     y_pred,
                                                     pos_label=1)
            auc = metrics.auc(fpr, tpr)
            auclist.append(auc)
        auc_summary = np.array(auclist)

        loss_dev_str = ' '.join(
            map(lambda x: '{:.5f}'.format(x), summary_dev['loss']))
        acc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['acc']))
        auc_dev_str = ' '.join(map(lambda x: '{:.3f}'.format(x), auc_summary))

        logging.info(
            '{}, Averaged Model -> Dev, Step : {}, Loss : {}, Acc : {}, Auc : {}, '
            'Mean auc: {:.3f} '
            'Run Time : {:.2f} sec'.format(time.strftime("%Y-%m-%d %H:%M:%S"),
                                           summary_train['step'], loss_dev_str,
                                           acc_dev_str, auc_dev_str,
                                           auc_summary.mean(), time_spent))

        for t in range(len(cfg.num_classes)):
            summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                      summary_dev['loss'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                      summary_dev['acc'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                      auc_summary[t], summary_train['step'])

        save_best = False

        mean_acc = summary_dev['acc'][cfg.save_index].mean()
        if mean_acc >= best_dict['acc_dev_best']:
            best_dict['acc_dev_best'] = mean_acc
            if cfg.best_target == 'acc':
                save_best = True

        mean_auc = auc_summary[cfg.save_index].mean()
        if mean_auc >= best_dict['auc_dev_best']:
            best_dict['auc_dev_best'] = mean_auc
            if cfg.best_target == 'auc':
                save_best = True

        mean_loss = summary_dev['loss'][cfg.save_index].mean()
        if mean_loss <= best_dict['loss_dev_best']:
            best_dict['loss_dev_best'] = mean_loss
            if cfg.best_target == 'loss':
                save_best = True

        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'acc_dev_best': best_dict['acc_dev_best'],
                    'auc_dev_best': best_dict['auc_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': avged_model.state_dict()
                },
                os.path.join(args.save_path,
                             'best{}.ckpt'.format(best_dict['best_idx'])))

            best_dict['best_idx'] += 1
            if best_dict['best_idx'] > cfg.save_top_k:
                best_dict['best_idx'] = 1
            logging.info('{}, Best, Step : {}, Loss : {}, Acc : {}, '
                         'Auc : {}, Best Auc : {:.3f}'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str, acc_dev_str,
                             auc_dev_str, best_dict['auc_dev_best']))
        torch.save(
            {
                'epoch': cr,
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': avged_model.state_dict()
            }, os.path.join(args.save_path, 'train.ckpt'))
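fed_avg and weighted_fed_avg are not defined in this snippet; federated averaging of client state_dicts is usually written roughly as below. This is a simplified sketch consistent with how they are called above (integer buffers such as num_batches_tracked may need special handling):

import copy

def fed_avg(w_locals):
    # Element-wise mean of a list of model state_dicts.
    w_avg = copy.deepcopy(w_locals[0])
    for key in w_avg:
        for w in w_locals[1:]:
            w_avg[key] = w_avg[key] + w[key]
        w_avg[key] = w_avg[key] / len(w_locals)
    return w_avg

def weighted_fed_avg(w_locals, proportions):
    # Average weighted by each client's share of the training data.
    w_avg = copy.deepcopy(w_locals[0])
    for key in w_avg:
        w_avg[key] = w_locals[0][key] * proportions[0]
        for w, p in zip(w_locals[1:], proportions[1:]):
            w_avg[key] = w_avg[key] + w[key] * p
    return w_avg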
Example 8
device = torch.device(f"cuda:{device_ids[0]}")

model = Classifier(cfg)
if args.verbose:
    from torchsummary import summary
    h, w = (cfg.long_side, cfg.long_side) if cfg.fix_ratio \
           else (cfg.height, cfg.width)
    summary(model.to(device), (3, h, w))

model = DataParallel(model, device_ids=device_ids).to(device)
if args.pre_train is not None:
    if exists(args.pre_train):
        ckpt = torch.load(args.pre_train, map_location=device)
        model.module.load_state_dict(ckpt)

optimizer = get_optimizer(model.parameters(), cfg)

trainset = ImageDataset(cfg.train_csv, cfg, mode='train')
testset = ImageDataset(cfg.dev_csv, cfg, mode='val')

trainloader = DataLoader(trainset, batch_size=cfg.train_batch_size,
    num_workers=args.num_workers, drop_last=True, shuffle=True)
testloader = DataLoader(testset, batch_size=cfg.dev_batch_size,
    num_workers=args.num_workers, drop_last=False, shuffle=False)

dev_header = testloader.dataset._label_header

# Initialize parameters to log training output

summary_train = {'epoch': 0, 'step': 0}
summary_dev = {'loss': float('inf'), 'acc': 0.0}
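The fragment above assumes device_ids has already been parsed; in the other examples on this page it comes straight from the command-line argument:

device_ids = list(map(int, args.device_ids.split(',')))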
Example 9
def train(config):
    # -- read hyperparameters --
    log.info("Hyper-parameters:")
    device = get_and_log(config, "device", -1)
    out_dir = get_and_log(config, "out_dir", "./output")
    config_dir = get_and_log(config, "config_dir", "./config")
    dataset_dir = get_and_log(config, "dataset_dir", "./dataset")
    validation_idxs_filepath = get_and_log(config, "train_validation_split")
    dataset_name = get_and_log(config, "dataset", required=True)
    atomic_nums = get_and_log(config, "atom_id_to_atomic_num", required=True)
    batch_size = get_and_log(config, "batch_size", required=True)
    num_epoch = get_and_log(config, "num_epoch", required=True)
    word_size = get_and_log(config, "embed_word_size", required=True)
    molecule_size = get_and_log(config, "molecule_size", required=True)
    num_atom_type = get_and_log(config, "num_atom_type", required=True)
    save_epoch = get_and_log(config, "save_epoch", -1)
    kekulized = get_and_log(config, "kekulize", False)
    layers = get_and_log(config, "layers", required=True)
    scale_adj = get_and_log(config, "scale_adj", True)
    log_name = get_and_log(config, "log_name", "log")
    optimizer_type = get_and_log(config, "optimizer", "adam")
    optimizer_params = get_and_log(config, "optimizer_params")
    snapshot = get_and_log(config, "snapshot")
    num_edge_type = 4 if kekulized else 5
    
    os.makedirs(out_dir, exist_ok=True)

    if validation_idxs_filepath is not None:
        validation_idxs = get_validation_idxs(os.path.join(config_dir, validation_idxs_filepath))
    else:
        validation_idxs = None

    # -- build dataset --
    dataset = NumpyTupleDataset.load(os.path.join(dataset_dir, dataset_name))
    if validation_idxs:
        train_idxs = [i for i in range(len(dataset)) if i not in validation_idxs]
        trainset_size = len(train_idxs)
        train_idxs.extend(validation_idxs)
        trainset, testset = chainer.datasets.split_dataset(dataset, trainset_size, train_idxs)
    else:
        trainset, testset = chainer.datasets.split_dataset_random(dataset, int(len(dataset) * 0.8), seed=777)
    
    train_iter = chainer.iterators.SerialIterator(trainset, batch_size, shuffle=True)
    test_iter = chainer.iterators.SerialIterator(testset, batch_size, repeat=False, shuffle=False)
    
    # -- model --
    model = AtomEmbedModel(word_size, num_atom_type, num_edge_type,
                           layers, scale_adj)
    model.save_hyperparameters(os.path.join(out_dir, "atom_embed_model_hyper.json"))
    
    # -- training details --
    if device >= 0:
        log.info("Using GPU")
        chainer.cuda.get_device(device).use()
        model.to_gpu(device)

    opt_func = get_optimizer(optimizer_type)
    if optimizer_params is not None:
        optimizer = opt_func(optimizer_params)
    else:
        optimizer = opt_func()
    
    optimizer.setup(model)
    updater = AtomEmbedUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (num_epoch, "epoch"), out=out_dir)
    save_epoch = save_epoch if save_epoch >= 0 else num_epoch
    
    # -- trainer extension --
    trainer.extend(extensions.snapshot(), trigger=(save_epoch, "epoch"))
    trainer.extend(extensions.LogReport(filename=log_name))
    trainer.extend(AtomEmbedEvaluator(test_iter, model, reporter=trainer.reporter, device=device))
    trainer.extend(extensions.PrintReport(["epoch", "ce_loss", "accuracy", "validation/ce_loss", "validation/accuracy", "elapsed_time"]))
    trainer.extend(extensions.PlotReport(["ce_loss", "validation/ce_loss"], x_key="epoch", filename="cross_entrypy_loss.png"))
    trainer.extend(extensions.PlotReport(["accuracy", "validation/accuracy"], x_key="epoch", filename="accuracy.png"))

    if snapshot is not None:
        chainer.serializers.load_npz(snapshot, trainer)
    trainer.run()
    chainer.serializers.save_npz(os.path.join(out_dir, "final_embed_model.npz"), model)
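get_and_log is not shown; from its call sites it fetches one key from the config with an optional default, logs the value, and can mark keys as required. A sketch under those assumptions:

import logging as log

def get_and_log(config, key, default=None, required=False):
    # Read config[key] (or fall back to the default), log it, and
    # raise if a required key is missing.
    if key in config:
        value = config[key]
    elif required:
        raise KeyError("missing required hyper-parameter: {}".format(key))
    else:
        value = default
    log.info("  {}: {}".format(key, value))
    return value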
Example 10
    def train(self,
              train_loader,
              val_loader,
              epochs=10,
              iter_log=100,
              use_lr_sch=False,
              resume=False,
              ckp_dir='./experiment/checkpoint',
              eval_metric='loss'):
        """Run training

        Args:
            train_loader (torch.utils.data.Dataloader): dataloader used for training
            val_loader (torch.utils.data.Dataloader): dataloader used for validation
            epochs (int, optional): number of training epochs. Defaults to 10.
            iter_log (int, optional): number of iterations between logging steps. Defaults to 100.
            use_lr_sch (bool, optional): use a learning rate scheduler. Defaults to False.
            resume (bool, optional): resume the training process. Defaults to False.
            ckp_dir (str, optional): path to the checkpoint directory. Defaults to './experiment/checkpoint'.
        """
        wandb.init(name=self.cfg.log_dir,
                   project='Pediatric Multi-label Classifier',
                   entity='dolphin')

        optimizer = get_optimizer(self.model.parameters(), self.cfg)

        if use_lr_sch:
            lr_sch = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                       lr_lambda=lrfn)
            lr_hist = []
        else:
            lr_sch = None
        best_metric = 0.0

        os.makedirs(ckp_dir, exist_ok=True)
        if resume:
            epoch_resume, iter_resume = self.load_ckp(
                os.path.join(ckp_dir, 'latest.ckpt'))
        else:
            epoch_resume = 1
            iter_resume = 0
        scaler = None
        if self.cfg.mix_precision:
            print('Train with mix precision!')
            scaler = torch.cuda.amp.GradScaler()
        for epoch in range(epoch_resume - 1, epochs):
            start = time.time()
            running_loss = AverageMeter()
            n_iter = len(train_loader)
            torch.set_grad_enabled(True)
            self.model.train()
            batch_weights = (1 / iter_log) * np.ones(n_iter)
            step_per_epoch = n_iter // iter_log
            if n_iter % iter_log:
                step_per_epoch += 1
                batch_weights[-(n_iter % iter_log):] = 1 / (n_iter % iter_log)
                iter_per_step = iter_log * \
                    np.ones(step_per_epoch, dtype=np.int16)
                iter_per_step[-1] = n_iter % iter_log
            else:
                iter_per_step = iter_log * \
                    np.ones(step_per_epoch, dtype=np.int16)
            i = 0
            # Create the dataloader iterator once per epoch so that each
            # logging step continues the same pass over the training data.
            iter_loader = iter(train_loader)
            for step in range(step_per_epoch):
                loop = tqdm.tqdm(range(iter_per_step[step]),
                                 total=iter_per_step[step])
                for iteration in loop:
                    data = next(iter_loader)
                    imgs, labels = data[0].to(self.device), data[1].to(
                        self.device)

                    if self.cfg.mix_precision:
                        with torch.cuda.amp.autocast():
                            preds = self.model(imgs)
                            loss = self.metrics['loss'](preds, labels)

                    else:
                        preds = self.model(imgs)
                        loss = self.metrics['loss'](preds, labels)

                    preds = nn.Sigmoid()(preds)
                    running_loss.update(loss.item(), imgs.shape[0])
                    optimizer.zero_grad()
                    if self.cfg.mix_precision:
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    i += 1

                if wandb:
                    wandb.log({'loss/train': running_loss.avg},
                              step=(epoch * n_iter) + (i + 1))
                s = "Epoch [{}/{}] Iter [{}/{}]:\n".format(
                    epoch + 1, epochs, i + 1, n_iter)
                s += "{}_{} {:.3f}\n".format('train', 'loss', running_loss.avg)
                running_metrics_test = self.test(val_loader, False)
                torch.set_grad_enabled(True)
                self.model.train()
                s = get_str(running_metrics_test, 'val', s)
                if wandb:
                    for key in running_metrics_test.keys():
                        if key != 'loss':
                            for j, disease_class in enumerate(
                                    np.array(
                                        train_loader.dataset.disease_classes)):
                                wandb.log(
                                    {
                                        key + '/' + disease_class:
                                        running_metrics_test[key][j]
                                    },
                                    step=(epoch * n_iter) + (i + 1))
                        else:
                            wandb.log(
                                {'loss/val': running_metrics_test['loss']},
                                step=(epoch * n_iter) + (i + 1))
                if self.cfg.type != 'chexmic':
                    metric_eval = running_metrics_test[eval_metric]
                else:
                    metric_eval = running_metrics_test[eval_metric][
                        self.id_obs]
                s = s[:-1] + "- mean_"+eval_metric + \
                    " {:.3f}".format(metric_eval.mean())
                self.save_ckp(os.path.join(ckp_dir, 'latest.ckpt'), epoch, i)
                running_loss.reset()
                end = time.time()
                s += " ({:.1f}s)".format(end - start)
                print(s)
                if metric_eval.mean() > best_metric:
                    best_metric = metric_eval.mean()
                    shutil.copyfile(os.path.join(ckp_dir, 'latest.ckpt'),
                                    os.path.join(ckp_dir, 'best.ckpt'))
                    print('new checkpoint saved!')
                start = time.time()
            if lr_sch is not None:
                lr_sch.step()
                print('current lr: {:.4f}'.format(lr_sch.get_lr()[0]))
        if lr_sch is not None:
            return lr_hist
        else:
            return None
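get_str, used in the first and last examples to extend the per-epoch log string, is not defined here. A sketch of the formatting its usage suggests (per-class metric arrays reduced to their mean; the exact layout is an assumption):

import numpy as np

def get_str(metrics_dict, mode, s):
    # Append '<mode>_<metric> <value>' entries to the running log string s.
    for key, value in metrics_dict.items():
        s += "{}_{} {:.3f} ".format(mode, key, np.asarray(value).mean())
    return s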