Code Example #1
File: tests.py Project: pombredanne/iterstuff
    def test_lookahead(self):
        l = Lookahead([])
        self.assertTrue(l.atstart)
        self.assertTrue(l.atend)
        self.assertIsNone(l.peek)
        self.assertEqual(len(list(l)), 0)

        l = Lookahead("a")
        self.assertTrue(l.atstart)
        self.assertFalse(l.atend)
        self.assertEqual(l.peek, "a")
        self.assertEqual(l.next(), "a")
        self.assertFalse(l.atstart)
        self.assertTrue(l.atend)
        self.assertRaises(StopIteration, l.next)

        l = Lookahead(range(10))
        self.assertTrue(l.atstart)
        self.assertFalse(l.atend)
        self.assertEqual(l.peek, 0)
        self.assertEqual(l.next(), 0)
        self.assertEqual(l.next(), 1)
        self.assertEqual(l.peek, 2)
        self.assertFalse(l.atstart)
        self.assertFalse(l.atend)
        self.assertEqual(list(l), [2, 3, 4, 5, 6, 7, 8, 9])
        self.assertTrue(l.atend)
Code Example #2
    def __init__(self, model, lr, weight_decay, batch):
        self.model = model
        # w - L2 regularization ; b - not L2 regularization
        weight_p, bias_p = [], []

        for p in self.model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for name, p in self.model.named_parameters():
            if 'bias' in name:
                bias_p += [p]
            else:
                weight_p += [p]
        # self.optimizer = optim.Adam([{'params': weight_p, 'weight_decay': weight_decay}, {'params': bias_p, 'weight_decay': 0}], lr=lr)
        self.optimizer_inner = RAdam([{
            'params': weight_p,
            'weight_decay': weight_decay
        }, {
            'params': bias_p,
            'weight_decay': 0
        }],
                                     lr=lr)
        self.optimizer = Lookahead(self.optimizer_inner, k=5, alpha=0.5)
        self.batch = batch
Code Example #3
    def build_model(self):
        """Build generator and discriminator."""
        self.set_seed(1)

        if self.model_type == 'UNet':
            self.unet = UNet()
        elif self.model_type == 'Squeeze_UNet':
            self.unet = Squeeze_UNet()
        elif self.model_type == 'Mobile_UNet':
            self.unet = Mobile_UNet()
        elif self.model_type == 'ShuffleV1_UNet':
            self.unet = ShuffleV1_UNet()
        elif self.model_type == 'ShuffleV2_UNet':
            self.unet = ShuffleV2_UNet()
        elif self.model_type == 'IGCV1_UNet':
            self.unet = IGCV1_UNet()

        self.optimizer = optim.Adam(list(self.unet.parameters()),
                                    self.lr, [self.beta1, self.beta2], weight_decay=0.00005)
        self.optimizer = Lookahead(self.optimizer, k=5, alpha=0.5)

        self.unet = nn.DataParallel(self.unet, self.device).cpu()

        self.init_weights(self.unet, init_type='kaiming')

        self.print_network(self.unet, self.model_type)
Code Example #4
File: model.py Project: 17854212083/MSCANet
def train(model, data):
    x, y = data
    # Split the data into training and validation sets
    train_x, valid_x, train_y, valid_y = train_test_split(x,
                                                          y,
                                                          test_size=0.2,
                                                          shuffle=True)
    # Set up checkpoints
    callbacks_list = [
        keras.callbacks.ModelCheckpoint(filepath=weight_path,
                                        monitor='loss',
                                        save_best_only=True),
        keras.callbacks.LearningRateScheduler(schedule),
    ]
    # Compile the model
    model.compile(optimizer=keras.optimizers.Adam(lr=lr), loss=total_loss)
    # Initialize Lookahead
    lookahead = Lookahead(k=5, alpha=0.5)
    # Inject it into the model
    lookahead.inject(model)
    # Start time
    start_time = time()
    # Train the model
    if not data_augmentation:
        print("Not using data augmentation")
        history = model.fit(train_x,
                            train_y,
                            epochs=epoch,
                            batch_size=batch_size,
                            validation_data=(valid_x, valid_y),
                            verbose=1,
                            callbacks=callbacks_list)
    else:
        print("Using data augmentation")
        history = model.fit_generator(
            generator=data_generator(train_x, train_y, batch_size),
            steps_per_epoch=(len(train_x) + batch_size - 1) // batch_size,
            epochs=epoch,
            verbose=1,
            callbacks=callbacks_list,
            validation_data=data_generator(valid_x, valid_y, batch_size),
            validation_steps=(len(valid_x) + batch_size - 1) // batch_size)
    model.save(weight_path)  # Save the model

    # End time
    duration = time() - start_time
    print("Train Finished takes:", "{:.2f} h".format(duration / 3600.0))
    plt.plot(history.history['loss'], label='train')
    plt.plot(history.history['val_loss'], label='valid')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc='upper right')
    plt.show()
    return model
Code Example #5
File: tests.py Project: pombredanne/iterstuff
    def test_lookahead(self):
        l = Lookahead([])
        self.assertTrue(l.atstart)
        self.assertTrue(l.atend)
        self.assertIsNone(l.peek)
        self.assertEqual(len(list(l)), 0)

        l = Lookahead("a")
        self.assertTrue(l.atstart)
        self.assertFalse(l.atend)
        self.assertEqual(l.peek, 'a')
        self.assertEqual(l.next(), 'a')
        self.assertFalse(l.atstart)
        self.assertTrue(l.atend)
        self.assertRaises(StopIteration, l.next)

        l = Lookahead(range(10))
        self.assertTrue(l.atstart)
        self.assertFalse(l.atend)
        self.assertEqual(l.peek, 0)
        self.assertEqual(l.next(), 0)
        self.assertEqual(l.next(), 1)
        self.assertEqual(l.peek, 2)
        self.assertFalse(l.atstart)
        self.assertFalse(l.atend)
        self.assertEqual(list(l), [2, 3, 4, 5, 6, 7, 8, 9])
        self.assertTrue(l.atend)
Code Example #6
File: recipes.py Project: pombredanne/iterstuff
def batch(iterable, size):
    """
    Yield iterables for successive slices of `iterable`, each containing
    up to `size` items, with the last being less than `size` if there are
    not sufficient items in `iterable`. Pass over the input iterable once
    only. Yield iterables, not lists.

    @note: each output iterable must be consumed in full before the next
     one is yielded. So list(batch(xrange(10), 3)) won't work as expected,
     because the iterables are not consumed.

    @param iterable: an input iterable.
    @param size: the maximum number of items yielded by any output iterable.
    """
    # Wrap an enumeration of the iterable in a Lookahead so that it
    # yields (count, element) tuples
    it = Lookahead(enumerate(iterable))

    while not it.atend:
        # Set the end_count using the count value
        # of the next element.
        end_count = it.peek[0] + size

        # Yield a generator that will then yield up to
        # 'size' elements from 'it'.
        yield (
            element
            for counter, element in repeatable_takewhile(
                # t[0] is the count part of each element
                lambda t: t[0] < end_count,
                it
            )
        )
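The caveat in the docstring above is easy to trip over, so here is a minimal usage sketch; it assumes batch() (together with its Lookahead and repeatable_takewhile dependencies) is in scope, e.g. imported from iterstuff. Each chunk is drained in full before the next one is requested, which is exactly what a bare list(batch(...)) would fail to do.

# Minimal usage sketch for the batch() recipe above; batch(), Lookahead and
# repeatable_takewhile are assumed to be in scope (e.g. imported from iterstuff).
for chunk in batch(range(10), 3):
    print(list(chunk))
# [0, 1, 2]
# [3, 4, 5]
# [6, 7, 8]
# [9]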
Code Example #7
    def __init__(self, args, state_dim, action_dim, action_num):
        mod_dim = args.mod_dim
        self.update_counter = 0
        self.update_gap = args.update_gap
        self.tau = args.soft_update_tau
        self.policy_noise = args.policy_noise
        self.gamma = args.gamma

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_num = action_num
        self.state_idx = 1 + 1 + state_dim  # reward_dim==1, done_dim==1, state_dim
        self.action_idx = self.state_idx + action_dim * action_num

        from torch import optim
        from lookahead import Lookahead
        self.act = Actor(mod_dim, state_dim, action_dim,
                         action_num).to(self.device)
        # self.act_optimizer = optim.Adam(self.act.parameters(), lr=4e-4)
        self.act_optimizer = Lookahead(optim.Adam(self.act.parameters(),
                                                  lr=4e-4),
                                       k=5,
                                       alpha=0.5)
        self.act.train()

        self.act_target = Actor(mod_dim, state_dim, action_dim,
                                action_num).to(self.device)
        self.act_target.load_state_dict(self.act.state_dict())
        self.act_target.eval()

        self.cri = Critic(mod_dim, state_dim, action_dim).to(self.device)
        # self.cri_optimizer = optim.Adam(self.cri.parameters(), lr=1e-3)
        self.cri_optimizer = Lookahead(optim.Adam(self.cri.parameters(),
                                                  lr=1e-3),
                                       k=5,
                                       alpha=0.5)
        self.cri.train()

        self.cri_target = Critic(mod_dim, state_dim,
                                 action_dim).to(self.device)
        self.cri_target.load_state_dict(self.cri.state_dict())
        self.cri_target.eval()

        self.criterion = nn.SmoothL1Loss()
Code Example #8
    def __init__(self, net, netsavepath, datasetpath, filename):
        self.net = net.to(device)
        self.netsavepath = netsavepath
        if os.path.exists(self.netsavepath):
            self.net = torch.load(self.netsavepath).to(device)
        print(net.__class__.__name__)
        self.writer = SummaryWriter(
            log_dir='./runs/{}loss'.format(str(net.__class__.__name__[0])))
        self.datasetpath = datasetpath
        self.lossconf = nn.BCELoss()
        self.lossoffset = nn.MSELoss()
        self.iouloss = nn.MSELoss()
        self.filename = filename

        # self.optimizer = optim.Adam(self.net.parameters())
        self.optimizer = torch.optim.Adam(self.net.parameters(),
                                          lr=1e-3,
                                          betas=(0.9, 0.999))  # Any optimizer
        self.lookahead = Lookahead(self.optimizer, k=5,
                                   alpha=0.5)  # Initialize Lookahead
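The snippet above only constructs the wrapped optimizer; a minimal sketch of how such a wrapper is typically driven in a training step (assuming a third-party Lookahead implementation that exposes the usual torch.optim interface, as the other examples here do) might look like this:

import torch
import torch.nn as nn
from lookahead import Lookahead  # assumed third-party implementation, as in the examples below

# Hypothetical minimal training step: Lookahead wraps any torch.optim optimizer
# and is driven through the same zero_grad()/step() interface as the base class.
model = nn.Linear(10, 1)
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
lookahead = Lookahead(base_optimizer, k=5, alpha=0.5)

x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = nn.MSELoss()(model(x), y)

lookahead.zero_grad()
loss.backward()
lookahead.step()  # fast-weight update; slow weights are synced every k steps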
Code Example #9
File: tests.py Project: pombredanne/iterstuff
    def test_repeatable_takewhile(self):
        data = Lookahead(x for x in 'abcd123ghi')

        self.assertEqual(
            list(repeatable_takewhile(lambda x: not x.isdigit(), data)),
            list('abcd'))

        self.assertEqual(
            list(repeatable_takewhile(lambda x: x.isdigit(), data)),
            list('123'))

        self.assertEqual(list(data), list('ghi'))
Code Example #10
def testLoadingAndAppendingSegments():
    la = Lookahead(59560.2, window_size=2000)
    la.load_segment()
    before = len(la.lookahead_mjds)
    la.load_segment()
    after = len(la.lookahead_mjds)
    assert before < after
Code Example #11
   def optimize(self, seqs, epochs, batch_size, ndel_func, init_lr=.01, min_lr=.001, lr_exp=1.01):

      # Optimizer
      base_opt = torch.optim.Adam(self.parameters(), lr=init_lr**(1/lr_exp))
      opt = Lookahead(base_opt, k=5, alpha=0.5)
      opt = base_opt  # NOTE: this reassignment bypasses the Lookahead wrapper and uses plain Adam
      #opt_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.1, patience=10, verbose=True)

      # Logs
      train_loss = np.zeros(epochs)
      test_loss  = np.zeros(epochs)

      for e in torch.arange(epochs):
         # Update LR
         cur_lr = base_opt.param_groups[0]['lr']
         if cur_lr > min_lr:
            base_opt.param_groups[0]['lr'] = max(min_lr,cur_lr**lr_exp)

         for x in seqs.train_batches(batch_size=batch_size):
            # Compute one-hot encoded targets
            target = sequence.to_onehot(x, seqs.n_symbols, dtype=torch.float32)
            # Apply random deletions
            xdel = sequence.random_dels(x, ndel_func)
            # Attention
            y = self(xdel, None)
            loss = torch.nn.functional.binary_cross_entropy(y, target)
            mem = torch.cuda.memory_allocated()/1024/1024

            # Update parameters
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            l = float(loss.detach().to('cpu'))
            train_loss[e] += l/(seqs.train_seq.shape[0]//batch_size)

         # Test data
         with torch.no_grad():
            # Generate test batch
            x = seqs.random_test_batch(batch_size=10*batch_size)
            # Compute one-hot encoded targets
            target = sequence.to_onehot(x, seqs.n_symbols, dtype=torch.float32)
            # Apply random deletions
            xdel = sequence.random_dels(x, ndel_func)
            # Attention
            y = self(xdel, None)
            test_loss[e] = float(torch.nn.functional.binary_cross_entropy(y, target).to('cpu'))

         # Optimizer lr schedule
         #opt_sched.step(train_loss[e])

         # Verbose epoch
         print("[epoch {}] train_loss={:.3f} test_loss={:.3f} memory={:.2f}MB".format(e+1, 100*train_loss[:e+1].mean(), 100*test_loss[:e+1].mean(), mem))

         # Save model
         if (e+1)%10 == 0:
            torch.save(self, 'saved_models/delattn_model_epoch{}.torch'.format(e))
            

      return train_loss, test_loss
Code Example #12
def model_fn(objective, optimizer, metrics):
    base_model = efn.EfficientNetB4(
        include_top=False,
        # base_model = seresnext50(include_top=False,
        # base_model = xception(include_top=False,
        # base_model = densenet201(include_top=False,
        # base_model = inceptionresnetv2(include_top=False,
        input_shape=(input_size, input_size, 3),
        classes=num_classes,
        weights='imagenet',
    )
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    predictions = Dense(num_classes, activation='softmax')(x)
    model1 = Model(inputs=base_model.input, outputs=predictions)
    #     model2 = multi_gpu_model(model1, gpus=3)
    #     model2 = model1
    model1.compile(loss=objective, optimizer=optimizer, metrics=metrics)
    lookahead = Lookahead(k=5, alpha=0.5)  # Initialize Lookahead
    lookahead.inject(model1)  # add into model
    model1.summary()
    return model1
Code Example #13
def testMemoryUsage():
    '''one simulated year's worth of loading and trimming'''
    la = Lookahead(healpix=True, nSides=8)
    la.load_segment()
    while la.date < 59885:
        la.date += 1
        la.start_night()
    assert True
Code Example #14
File: train.py Project: 12345fengce/AI_CV
    def __init__(self, mode: str, batch_size: int):
        "Device Config"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        "Model Config"
        self.mode = mode
        if mode == "P" or mode == "p":
            self.size = 12
            self.threshold = 0.9
            self.net = net.PNet().to(self.device)
        elif mode == "R" or mode == "r":
            self.size = 24
            self.threshold = 0.99
            self.net = net.RNet().to(self.device)
        elif mode == "O" or mode == "o":
            self.size = 48
            self.threshold = 0.999
            self.net = net.ONet().to(self.device)
        if len(os.listdir("./params")) > 3:
            print("MODE: {} >>> Loading ... ...".format(mode))
            self.net.load_state_dict(torch.load("./params/{}net.pkl".format(mode.lower())))

        "Dataloader Config"
        self.train = data.DataLoader(dataset.choice(mode.lower()), batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
        self.test = data.DataLoader(dataset.choice("{}v".format(mode.lower())), batch_size=512, shuffle=True, drop_last=True)

        "Optim Config"
        optimize = optim.SGD(self.net.parameters(), lr=3e-4, momentum=0.9)
        self.lookahead = Lookahead(optimize, k=5, alpha=0.5)
        # self.lr = optim.lr_scheduler.CosineAnnealingLR(self.lookahead, T_max=1, eta_min=1e-5, last_epoch=-1)

        "Loss Config"
        self.loss_confi = nn.BCELoss()
        self.loss_resolve = nn.BCELoss(reduction="none")
        self.loss_offset = nn.MSELoss()

        "Show Config"
        self.summarywriter = SummaryWriter(log_dir="./runs/{}_runs".format(mode.lower()))
Code Example #15
def testDateIndex():
    '''test that when we ask for a date that isn't an exact value in the date table, we get 
    the closest available date'''
    la = Lookahead()
    la.load_segment()
    la.populate_lookahead_window()
    print(la.dateindex(59550.0278))
    print(la.date)
    print(la.lookahead_mjds[7])
    print(la.lookahead_mjds[8])
    print(la.lookahead_mjds[9])
    print(la.lookahead_mjds[10])
Code Example #16
File: model.py Project: Samantha09/transformerCPI2
class Trainer(object):
    def __init__(self, model, lr, weight_decay, batch):
        self.model = model
        # w - L2 regularization ; b - not L2 regularization
        weight_p, bias_p = [], []

        for p in self.model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for name, p in self.model.named_parameters():
            if 'bias' in name:
                bias_p += [p]
            else:
                weight_p += [p]
        # self.optimizer = optim.Adam([{'params': weight_p, 'weight_decay': weight_decay}, {'params': bias_p, 'weight_decay': 0}], lr=lr)
        self.optimizer_inner = RAdam(
            [{'params': weight_p, 'weight_decay': weight_decay}, {'params': bias_p, 'weight_decay': 0}], lr=lr)
        self.optimizer = Lookahead(self.optimizer_inner, k=5, alpha=0.5)
        self.batch = batch

    def train(self, dataset, device):
        self.model.train()
        # dataset_iter = iter(dataset)
        # np.random.shuffle(dataset)
        N = sum(1 for _ in deepcopy(dataset))
        # print("len of datasets: ", N)
        loss_total = 0
        i = 0
        self.optimizer.zero_grad()
        adjs, atoms, proteins, labels = [], [], [], []
        # TODO: progress bar
        for _ in tqdm(range(N), ascii=True):
            data = next(dataset)
            i = i+1
            atom, adj, protein, label = data
            # TODO: move the tensors to the GPU
            if torch.cuda.is_available():
                atom, adj, protein, label = atom.cuda(), adj.cuda(), protein.cuda(), label.cuda()
            adjs.append(adj)
            atoms.append(atom)
            proteins.append(protein)
            labels.append(label)
            if i % 8 == 0 or i == N:
                data_pack = pack(atoms, adjs, proteins, labels, device)
                loss = self.model(data_pack)
                # loss = loss / self.batch
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10)
                adjs, atoms, proteins, labels = [], [], [], []
            else:
                continue
            if i % self.batch == 0 or i == N:
                self.optimizer.step()
                self.optimizer.zero_grad()
            loss_total += loss.item()

        return loss_total
Code Example #17
def train(num_epochs, model, data_loader, val_loader, val_every, device, file_name):
    learning_rate = 0.0001
    from torch.optim.swa_utils import AveragedModel, SWALR
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from segmentation_models_pytorch.losses import SoftCrossEntropyLoss, JaccardLoss
    from adamp import AdamP

    criterion = [SoftCrossEntropyLoss(smooth_factor=0.1), JaccardLoss('multiclass', classes=12)]
    optimizer = AdamP(params=model.parameters(), lr=learning_rate, weight_decay=1e-6)
    swa_scheduler = SWALR(optimizer, swa_lr=learning_rate)
    swa_model = AveragedModel(model)
    look = Lookahead(optimizer, la_alpha=0.5)

    print('Start training..')
    best_miou = 0
    for epoch in range(num_epochs):
        hist = np.zeros((12, 12))
        model.train()
        for step, (images, masks, _) in enumerate(data_loader):
            loss = 0
            images = torch.stack(images)  # (batch, channel, height, width)
            masks = torch.stack(masks).long()  # (batch, channel, height, width)

            # Move tensors to the device for GPU computation
            images, masks = images.to(device), masks.to(device)

            # inference
            outputs = model(images)
            for i in criterion:
                loss += i(outputs, masks)
            # Loss computation (cross entropy loss)

            look.zero_grad()
            loss.backward()
            look.step()

            outputs = torch.argmax(outputs.squeeze(), dim=1).detach().cpu().numpy()
            hist = add_hist(hist, masks.detach().cpu().numpy(), outputs, n_class=12)
            acc, acc_cls, mIoU, fwavacc = label_accuracy_score(hist)
            # Print loss and mIoU at the step interval
            if (step + 1) % 25 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, mIoU: {:.4f}'.format(
                    epoch + 1, num_epochs, step + 1, len(data_loader), loss.item(), mIoU))

        # Print the loss at each validation interval and save the best model
        if (epoch + 1) % val_every == 0:
            avrg_loss, val_miou = validation(epoch + 1, model, val_loader, criterion, device)
            if val_miou > best_miou:
                print('Best performance at epoch: {}'.format(epoch + 1))
                print('Save model in', saved_dir)
                best_miou = val_miou
                save_model(model, file_name = file_name)

        if epoch > 3:
            swa_model.update_parameters(model)
            swa_scheduler.step()
Code Example #18
class Trainer(object):
    def __init__(self, model, lr, weight_decay, batch):
        self.model = model
        # w - L2 regularization ; b - not L2 regularization
        weight_p, bias_p = [], []

        for p in self.model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for name, p in self.model.named_parameters():
            if 'bias' in name:
                bias_p += [p]
            else:
                weight_p += [p]
        # self.optimizer = optim.Adam([{'params': weight_p, 'weight_decay': weight_decay}, {'params': bias_p, 'weight_decay': 0}], lr=lr)
        self.optimizer_inner = RAdam([{
            'params': weight_p,
            'weight_decay': weight_decay
        }, {
            'params': bias_p,
            'weight_decay': 0
        }],
                                     lr=lr)
        self.optimizer = Lookahead(self.optimizer_inner, k=5, alpha=0.5)
        self.batch = batch

    def train(self, dataset, device):
        self.model.train()
        np.random.shuffle(dataset)
        N = len(dataset)
        loss_total = 0
        i = 0
        self.optimizer.zero_grad()
        adjs, atoms, proteins, labels = [], [], [], []
        for data in dataset:
            i = i + 1
            atom, adj, protein, label = data
            adjs.append(adj)
            atoms.append(atom)
            proteins.append(protein)
            labels.append(label)
            if i % 8 == 0 or i == N:
                data_pack = pack(atoms, adjs, proteins, labels, device)
                loss = self.model(data_pack)
                # loss = loss / self.batch
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10)
                adjs, atoms, proteins, labels = [], [], [], []
            else:
                continue
            if i % self.batch == 0 or i == N:
                self.optimizer.step()
                self.optimizer.zero_grad()
            loss_total += loss.item()
        return loss_total
Code Example #19
File: train.py Project: yangyuren03/rsna-2019
def make_optimizer(args, model):
    lr = args.lr
    # weight_decay = 1e-2
    # weight_decay_bias = 0

    params_list = []
    if args.input_level == 'per-study':
        encoder, decoder = model
        # for m in model:
        #     for key, value in m.named_parameters():
        #         print(key)
        #         if not value.requires_grad:
        #             continue
        #         if 'bias' in key:
        #             # x2 lr for bias, and weight decay for bias is 0
        #             params_list.append({'params': [value], 'lr': lr*2, 'weight_decay': weight_decay_bias})
        #         else:
        #             # weight decay for weight is 1e-2
        #             params_list.append({'params': [value], 'lr': lr, 'weight_decay': weight_decay})

        params_list.append({'params': encoder.parameters()})
        params_list.append({'params': decoder.parameters()})
    else:
        params_list.append({'params': model.parameters()})

    if args.optim == 'adam':
        optimizer = torch.optim.Adam(params_list, lr=lr, eps=1e-3)
    elif args.optim == 'radam':
        from radam import RAdam
        optimizer = RAdam(params_list, lr=lr, eps=1e-3)
    elif args.optim == 'adamw':
        from radam import AdamW
        optimizer = AdamW(params_list, lr=lr, eps=1e-3)
    elif args.optim == 'sgd':
        # lr = lr*100 as the data suggest; SGD is much slower than Adam
        optimizer = torch.optim.SGD(params_list, lr=lr, momentum=0.9)
    else:
        raise ValueError("Unknown optimizer")

    if args.lookahead:
        from lookahead import Lookahead
        optimizer = Lookahead(optimizer)

    return optimizer
Code Example #20
File: recipes.py Project: pombredanne/iterstuff
def chunked(i, f=lambda _x: _x):
    """
    Given an iterable i, apply f over it to extract a value from
    each element and yield successive iterables where the result
    of f for all elements is the same.

    In simpler language, if i is an iterable sorted on some key, yield
    chunks of that list where the key value is the same, each chunk being
    a separate iterable.

    Note that this function yields B{iterators}, not lists, and they refer
    back to the iterator passed in, so each B{must} be consumed completely
    before the next one is requested.

    @param i: an iterable.
    @param f: a function to be applied to each element of the iterable to
    extract the key.
    """
    # Build a generator that return tuples of (element, key-of-element),
    # so that we only apply the key method to each element once.
    it = Lookahead((_x, f(_x)) for _x in i)

    def takechunk():
        """
        A generator closure that will yield values while the keys remain
        the same. Note that we cannot use L{itertools.takewhile} for this,
        because that takes elements and B{then} checks the predicate, so
        successive calls to itertools.takewhile for the same generator will
        skip elements.
        """
        while True:
            # Always yield the first element: if we're at the end of the
            # generator, this will raise StopIteration and we're done.
            (_x, key) = it.next()
            yield _x

            # Check the lookahead's peek value to see if we should break now.
            # We also break when we're at the end of the generator.
            if it.atend or key != it.peek[1]:
                break

    # Yield successive instances of takechunk.
    while not it.atend:
        yield takechunk()
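As with the batch() recipe, chunked() only behaves correctly when each chunk is consumed before the next one is requested; a minimal usage sketch over a key-sorted list (assuming chunked() and its Lookahead dependency are in scope) looks like this:

# Minimal usage sketch for the chunked() recipe above; chunked() and Lookahead
# are assumed to be in scope (e.g. imported from iterstuff). The input is sorted by key.
data = [('a', 1), ('a', 2), ('b', 3), ('b', 4), ('c', 5)]
for chunk in chunked(data, f=lambda t: t[0]):
    print(list(chunk))
# [('a', 1), ('a', 2)]
# [('b', 3), ('b', 4)]
# [('c', 5)]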
Code Example #21
def init_optimizer(args, model, criterion):
    params = list(model.parameters()) + list(criterion.parameters())
    total_params = sum(x.size()[0] *
                       x.size()[1] if len(x.size()) > 1 else x.size()[0]
                       for x in params if x.size())
    print('Args:', args)
    print('Model total parameters:', total_params)

    optimizer = None
    # Ensure the optimizer is optimizing params, which includes both the model's weights as well as the criterion's weight (i.e. Adaptive Softmax)
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    weight_decay=args.wdecay)
    if args.optimizer == 'adagrad':
        optimizer = torch.optim.Adagrad(params,
                                        lr=args.lr,
                                        weight_decay=args.wdecay)
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.wdecay)
    if args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(params,
                                      lr=args.lr,
                                      weight_decay=args.wdecay)
    if args.optimizer == 'lamb':
        from pytorch_lamb import Lamb
        optimizer = Lamb(params,
                         lr=args.lr,
                         weight_decay=args.wdecay,
                         min_trust=0.25)
        # optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0.1)
        # optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0, random_min_trust=0.2, random_trust_dice=10)
        # optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0.2, random_min_trust=0.5, random_trust_dice=4)
    from lookahead import Lookahead
    if False:
        k, alpha = 5, 0.8
        print('Lookahead - k {} and alpha {}'.format(k, alpha))
        optimizer = Lookahead(base_optimizer=optimizer, k=k, alpha=alpha)

    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    # model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
    return model, optimizer, params
Code Example #22
def testArrayShapes():
    '''makes sure the numpy arrays we're passing around have the right number
    of dimensions'''

    la = Lookahead()
    la.load_segment()

    #dates should be same length as sky map arrays
    assert la.lookahead[u'u'].shape[0] == la.lookahead_mjds.shape[0]

    #dates should be 1-dimensional, everything else should be 2-dimensional
    assert len(la.lookahead_mjds.shape) == 1
    for k in la.keys:
        assert len(la.lookahead[k].shape) == 2

    #trim the array, load some more data, and check again
    la.date += 5
    la.start_night()
    la.load_segment()

    assert la.lookahead[u'u'].shape[0] == la.lookahead_mjds.shape[0]
    assert len(la.lookahead_mjds.shape) == 1
    for k in la.keys:
        assert len(la.lookahead[k].shape) == 2
Code Example #23
def testStartNightArraySlicing():

    la = Lookahead(59560.2, window_size=30, healpix=True, nSides=8)
    la.load_segment()
    before = len(la.lookahead_mjds)
    la.date = 59561.2
    la.start_night()
    after = len(la.lookahead_mjds)
    print(before)
    print(after)

    assert before > after

    assert len(la.lookahead_mjds) == len(la.lookahead[u'u'])
    assert len(la.lookahead_mjds) == len(la.lookahead[u'g'])
    assert len(la.lookahead_mjds) == len(la.lookahead[u'r'])
    assert len(la.lookahead_mjds) == len(la.lookahead[u'i'])
    assert len(la.lookahead_mjds) == len(la.lookahead[u'z'])
    assert len(la.lookahead_mjds) == len(la.lookahead[u'y'])
    assert len(la.lookahead_mjds) == len(la.lookahead[u'moonangle'])
    assert len(la.lookahead_mjds) == len(la.lookahead[u'airmass'])
Code Example #24
File: model.py Project: GuanJZ/Behavioral-Cloning
def nvidia(optimizer, source_path, train_generator, validation_generator, train_epochs, \
    num_train_samples, num_validation_samples, batch_size, conv_dropout, fc_dropout):
    '''
    :nvidia model from the paper:
        https://images.nvidia.com/content/tegra/automotive/images/2016/solutions/pdf/end-to-end-dl-using-px.pdf
    '''
    model = Sequential()  # assumed: the original snippet calls model.add(...) without defining model
    # Layer0: normalization layer
    model.add(Lambda(lambda x: (x / 255.0) - 0.5, input_shape=(64, 64, 3)))
    # Layer1: Convolutional feature map 24@31x98
    model.add(Convolution2D(24, 5, 5, activation='relu', subsample=(2, 2)))
    model.add(Dropout(conv_dropout))
    # Layer2: Convolutional feature map 36@14x47
    model.add(Convolution2D(36, 5, 5, activation='relu', subsample=(2, 2)))
    model.add(Dropout(conv_dropout))
    # Layer3: Convolutional feature map 48@5x22
    model.add(Convolution2D(48, 5, 5, activation='relu', subsample=(2, 2)))
    model.add(Dropout(conv_dropout))
    # Layer4: Convolutional feature map 64@3x20
    model.add(Convolution2D(64, 3, 3, activation='relu'))
    model.add(Dropout(conv_dropout))
    # Layer5: Convolutional feature map 64@1x18
    model.add(Convolution2D(64, 3, 3, activation='relu'))
    model.add(Dropout(conv_dropout))
    # FC1
    model.add(Flatten())
    model.add(Dense(1164, activation='relu'))
    model.add(Dropout(fc_dropout))
    # FC2
    model.add(Dense(100, activation='relu'))
    model.add(Dropout(fc_dropout))
    # FC3
    model.add(Dense(50, activation='relu'))
    model.add(Dropout(fc_dropout))
    # FC4
    model.add(Dense(10, activation='relu'))
    # Output layer
    model.add(Dense(1))

    model.summary()

    if optimizer == 'adam':
        '''
        Default Adam optimizer;
        works for trace1
        '''
        model.compile(loss='mse', optimizer='adam')

    if optimizer == 'sgd':
        '''
        SGD optimizer;
        works for trace2 with a 0.005 learning rate
        '''
        sgd = optimizers.SGD(lr=0.005, decay=1e-6, momentum=0.9, nesterov=True)
        model.compile(loss='mse', optimizer=sgd)

    if optimizer == 'lookahead':
        '''
        The Lookahead optimizer
        paper: https://arxiv.org/abs/1907.08610
        Keras implementation: https://github.com/bojone/keras_lookahead
        '''
        model.compile(optimizer=optimizers.Adam(1e-3),
                      loss='mse')  # Any optimizer
        lookahead = Lookahead(k=5, alpha=0.5)  # Initialize Lookahead
        lookahead.inject(model)  # add into model

    history_object = model.fit_generator(
        train_generator,
        steps_per_epoch=np.ceil(num_train_samples / batch_size),
        validation_data=validation_generator,
        validation_steps=np.ceil(num_validation_samples / batch_size),
        epochs=train_epochs,
        verbose=1)

    model.save('model.h5')
    print('nvidia-model-epoch{}-{}-{}.h5'.format(source_path, train_epochs,
                                                 optimizer))

    plot_loss(history_object)
Code Example #25
    def forward(self, preds, one_hot_target):
        preds = preds.log_softmax(dim=self.dim)
        return torch.mean(torch.sum(-one_hot_target * preds, dim=self.dim))


# loss_fn = nn.CrossEntropyLoss()
loss_fn = LabelSmoothingLoss()
device = torch.device('cuda:0')

#epochs = 50
epochs = 100
patience = 15

opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
opt = Lookahead(opt)
model = model.to(device)
rolling_loss = dict(train=RollingLoss(), valid=RollingLoss())
steps = dict(train=0, valid=0)

trials = 0
best_metric = -np.inf
history = []
stop = False

vis = Visdom(server='0.0.0.0',
             port=9091,
             username=os.environ['VISDOM_USERNAME'],
             password=os.environ['VISDOM_PASSWORD'])

# loaders = create_loaders(batch_size=7)
Code Example #26
class AgentDelayDDPG:
    def __init__(self, args, state_dim, action_dim, action_num):
        mod_dim = args.mod_dim
        self.update_counter = 0
        self.update_gap = args.update_gap
        self.tau = args.soft_update_tau
        self.policy_noise = args.policy_noise
        self.gamma = args.gamma

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_num = action_num
        self.state_idx = 1 + 1 + state_dim  # reward_dim==1, done_dim==1, state_dim
        self.action_idx = self.state_idx + action_dim * action_num

        from torch import optim
        from lookahead import Lookahead
        self.act = Actor(mod_dim, state_dim, action_dim,
                         action_num).to(self.device)
        # self.act_optimizer = optim.Adam(self.act.parameters(), lr=4e-4)
        self.act_optimizer = Lookahead(optim.Adam(self.act.parameters(),
                                                  lr=4e-4),
                                       k=5,
                                       alpha=0.5)
        self.act.train()

        self.act_target = Actor(mod_dim, state_dim, action_dim,
                                action_num).to(self.device)
        self.act_target.load_state_dict(self.act.state_dict())
        self.act_target.eval()

        self.cri = Critic(mod_dim, state_dim, action_dim).to(self.device)
        # self.cri_optimizer = optim.Adam(self.cri.parameters(), lr=1e-3)
        self.cri_optimizer = Lookahead(optim.Adam(self.cri.parameters(),
                                                  lr=1e-3),
                                       k=5,
                                       alpha=0.5)
        self.cri.train()

        self.cri_target = Critic(mod_dim, state_dim,
                                 action_dim).to(self.device)
        self.cri_target.load_state_dict(self.cri.state_dict())
        self.cri_target.eval()

        self.criterion = nn.SmoothL1Loss()

    def select_action(self, state):
        state = torch.tensor((state, ),
                             dtype=torch.float32,
                             device=self.device)
        action = self.act(state).cpu().data.numpy()
        return action[0]

    def soft_update(self, target, source):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, memories, iter_num, batch_size):
        actor_loss_avg, critic_loss_avg = 0, 0

        k = 1 + memories.size / memories.memories_num
        iter_num = int(k * iter_num)
        batch_size = int(k * batch_size)

        for i in range(iter_num):
            with torch.no_grad():
                memory = memories.sample(batch_size)
                memory = torch.tensor(memory,
                                      dtype=torch.float32,
                                      device=self.device)
                reward = memory[:, 0:1]
                undone = memory[:, 1:2]
                state = memory[:, 2:self.state_idx]
                action = memory[:, self.state_idx:self.action_idx]
                next_state = memory[:, self.action_idx:]

                noise = torch.randn(action.size(),
                                    dtype=torch.float32,
                                    device=self.device) * self.policy_noise

            next_action = self.act_target(next_state) + noise
            next_action = next_action.clamp(-1.0, 1.0)

            with torch.no_grad():
                q_target = self.cri_target(next_state, next_action)
                q_target = reward + undone * self.gamma * q_target

            q_eval = self.cri(state, action)
            critic_loss = self.criterion(q_eval, q_target)
            critic_loss_avg += critic_loss.item()
            self.cri_optimizer.zero_grad()
            critic_loss.backward()
            self.cri_optimizer.step()

            actor_loss = -self.cri(state, self.act(state)).mean()
            actor_loss_avg += actor_loss.item()
            self.act_optimizer.zero_grad()
            actor_loss.backward()
            self.act_optimizer.step()

            self.update_counter += 1
            if self.update_counter == self.update_gap:
                self.update_counter = 0
                # self.act_target.load_state_dict(self.act.state_dict())
                # self.cri_target.load_state_dict(self.cri.state_dict())
                self.soft_update(self.act_target, self.act)
                self.soft_update(self.cri_target, self.cri)

        actor_loss_avg /= iter_num
        critic_loss_avg /= iter_num
        return actor_loss_avg, critic_loss_avg

    def save(self, mod_dir):
        torch.save(self.act.state_dict(), '%s/actor.pth' % (mod_dir, ))
        torch.save(self.act_target.state_dict(),
                   '%s/actor_target.pth' % (mod_dir, ))

        torch.save(self.cri.state_dict(), '%s/critic.pth' % (mod_dir, ))
        torch.save(self.cri_target.state_dict(),
                   '%s/critic_target.pth' % (mod_dir, ))
        print("Saved:", mod_dir)

    def load(self, mod_dir, load_actor_only=False):
        print("Loading:", mod_dir)
        self.act.load_state_dict(
            torch.load('%s/actor.pth' % (mod_dir, ),
                       map_location=lambda storage, loc: storage))
        self.act_target.load_state_dict(
            torch.load('%s/actor_target.pth' % (mod_dir, ),
                       map_location=lambda storage, loc: storage))

        if load_actor_only:
            print("load_actor_only!")
        else:
            self.cri.load_state_dict(
                torch.load('%s/critic.pth' % (mod_dir, ),
                           map_location=lambda storage, loc: storage))
            self.cri_target.load_state_dict(
                torch.load('%s/critic_target.pth' % (mod_dir, ),
                           map_location=lambda storage, loc: storage))
Code Example #27
if __name__ == '__main__':
    datas = Mydataset()
    imageDataloader = data.DataLoader(dataset=datas,
                                      batch_size=1,
                                      shuffle=True)

    net = MainNet(14).to(device)
    if os.path.exists(savepath):
        net.load_state_dict(torch.load(savepath))

    optim = torch.optim.Adam(net.parameters(), weight_decay=4e-4)

    base_opt = torch.optim.Adam(net.parameters(), lr=1e-3,
                                betas=(0.9, 0.999))  # Any optimizer
    lookahead = Lookahead(base_opt, k=5, alpha=0.5)  # Initialize Lookahead

    losses = []
    epoch = 0
    while True:
        for i, (label13, label26, label52, img) in enumerate(imageDataloader):
            output_13, output_26, output_52 = net(img.to(device))

            loss_13 = loss_fn(output_13, label13, 0.9)
            loss_26 = loss_fn(output_26, label26, 0.9)
            loss_52 = loss_fn(output_52, label52, 0.9)

            loss = loss_13 + loss_26 + loss_52
            losses.append(loss)
            lookahead.zero_grad()
            loss.backward()  # Self-defined loss function
Code Example #28
    elif isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight)


if __name__ == '__main__':
    weight_save_path = r"./params/arc_loss_test.pt"
    dataset = datasets.MNIST(root="../MyTest01/minist_torch/", train=True, transform=transforms.ToTensor(),
                             download=True)
    dataloader = DataLoader(dataset=dataset, batch_size=1024, shuffle=True, num_workers=5)
    cls_net = ClsNet()
    cls_net.cuda()
    if os.path.exists(weight_save_path):
        cls_net.load_state_dict(torch.load(weight_save_path))
    fig, ax = plt.subplots()
    optimizer = optim.Adam(cls_net.parameters())
    lookahead = Lookahead(optimizer)
    for epoch in range(100000):
        for i, (xs, ys) in enumerate(dataloader):
            xs = xs.cuda()
            ys = ys.cuda()
            coordinate, out = cls_net(xs, 1, 1)
            coordinate = coordinate.cpu().detach().numpy()
            loss = cls_net.get_loss(out, ys)
            # print([i for i, c in cls_net.named_parameters()])
            # exit()

            lookahead.zero_grad()
            loss.backward()
            lookahead.step()

            print(loss.cpu().detach().item())
Code Example #29
def run():
    d = {
        'image_id': os.listdir(config.TRAIN_IMAGE_PATH),
        'mask_id': os.listdir(config.TRAIN_MASK_PATH)
    }
    df = pd.DataFrame(data=d)

    folds = df.copy()

    kf = KFold(n_splits=config.N_FOLDS, shuffle=True, random_state=42)

    for fold, (train_idx, valid_idx) in enumerate(kf.split(folds)):

        print(f'FOLD: {fold+1}/{config.N_FOLDS}')

        train_test = folds.iloc[train_idx]
        valid_test = folds.iloc[valid_idx]

        train_test.reset_index(drop=True, inplace=True)
        valid_test.reset_index(drop=True, inplace=True)

        train_dataset = dataset.HuBMAPDataset(
            train_test, transforms=transforms.transforms_train)
        train_loader = DataLoader(train_dataset,
                                  batch_size=config.TRAIN_BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=config.NUM_WORKERS)

        valid_dataset = dataset.HuBMAPDataset(
            valid_test, transforms=transforms.transforms_valid)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=config.VALID_BATCH_SIZE,
                                  shuffle=False,
                                  num_workers=config.NUM_WORKERS)

        loss_history = {"train": [], "valid": []}

        dice_history = {"train": [], "valid": []}

        jaccard_history = {"train": [], "valid": []}

        dice_max = 0.0
        kernel_type = 'unext50'
        best_file = f'../drive/MyDrive/{kernel_type}_best_fold{fold}_strong_aug_70_epochs.bin'

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        model = UneXt50().to(device)
        optimizer = Lookahead(RAdam(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=config.LR),
                              alpha=0.5,
                              k=5)
        # base_opt = optim.Adam(model.parameters(), lr=3e-4)
        # optimizer = SWA(base_opt)
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, config.N_EPOCHS)
        # scheduler = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=config.WARMUP_EPO, after_scheduler=scheduler_cosine)
        loss_fn = metrics.symmetric_lovasz

        for epoch in range(config.N_EPOCHS):

            scheduler.step(epoch)
            avg_train_loss, train_dice_scores, train_jaccard_scores = engine.train_loop_fn(
                model, train_loader, optimizer, loss_fn,
                metrics.dice_coef_metric, metrics.jaccard_coef_metric, device)

            # if epoch > 10 and epoch % 5 == 0:
            #   optimizer.update_swa()

            loss_history["train"].append(avg_train_loss)
            dice_history["train"].append(train_dice_scores)
            jaccard_history["train"].append(train_jaccard_scores)

            avg_val_loss, val_dice_scores, val_jaccard_scores = engine.val_loop_fn(
                model, valid_loader, optimizer, loss_fn,
                metrics.dice_coef_metric, metrics.jaccard_coef_metric, device)

            loss_history["valid"].append(avg_val_loss)
            dice_history["valid"].append(val_dice_scores)
            jaccard_history["valid"].append(val_jaccard_scores)

            print(
                f"Epoch: {epoch+1} | lr: {optimizer.param_groups[0]['lr']:.7f} | train loss: {avg_train_loss:.4f} | val loss: {avg_val_loss:.4f}"
            )
            print(
                f"train dice: {train_dice_scores:.4f} | val dice: {val_dice_scores:.4f} | train jaccard: {train_jaccard_scores:.4f} | val jaccard: {val_jaccard_scores:.4f}"
            )

            if val_dice_scores > dice_max:
                print('score2 ({:.6f} --> {:.6f}).  Saving model ...'.format(
                    dice_max, val_dice_scores))
                torch.save(model.state_dict(), best_file)
                dice_max = val_dice_scores
Code Example #30
        optimizer = torch.optim.Adagrad(params, lr=args.lr, weight_decay=args.wdecay)
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay)
    if args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wdecay)
    if args.optimizer == 'lamb':
        from pytorch_lamb import Lamb
        optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0.25)
        #optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0.1)
        #optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0, random_min_trust=0.2, random_trust_dice=10)
        #optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0.2, random_min_trust=0.5, random_trust_dice=4)
    from lookahead import Lookahead
    if False:
        k, alpha = 5, 0.8
        print('Lookahead - k {} and alpha {}'.format(k, alpha))
        optimizer = Lookahead(base_optimizer=optimizer, k=k, alpha=alpha)

    from apex import amp
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    #model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    for epoch in range(1, args.epochs+1):
        epoch_start_time = time.time()
        train(epoch - 1)
        if 't0' in optimizer.param_groups[0]:
            tmp = {}
            for prm in model.parameters():
                tmp[prm] = prm.data.clone()
                prm.data = optimizer.state[prm]['ax'].clone()

            val_loss2 = evaluate(val_data)
Code Example #31
def model_trainer(args):
    # Load MNIST
    data_root = './'
    train_set = MNIST(root=data_root, download=True, train=True, transform=ToTensor())
    train_sampler = DistributedSampler(train_set)
    same_seeds(args.seed_num)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=(train_sampler is None), pin_memory=True, sampler=train_sampler)
    valid_set = MNIST(root=data_root, download=True, train=False, transform=ToTensor())
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size, shuffle=False, pin_memory=True)

    print(f'Now Training: {args.exp_name}')
    # Load model
    same_seeds(args.seed_num)
    model = Toy_Net()
    model = model.to(args.local_rank)

    # Model parameters
    os.makedirs(f'./experiment_model/', exist_ok=True)
    latest_model_path = f'./experiment_model/{args.exp_name}'
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, nesterov=True)
    lookahead = Lookahead(optimizer=optimizer, k=10, alpha=0.5)
    loss_function = nn.CrossEntropyLoss()
    if args.local_rank == 0:
        best_valid_acc = 0

    # Callbacks
    warm_up = lambda epoch: epoch / args.warmup_epochs if epoch <= args.warmup_epochs else 1
    scheduler_wu = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=warm_up)
    scheduler_re = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.1, patience=10, verbose=True)
    early_stopping = EarlyStopping(patience=50, verbose=True)
            
    # Apex
    #amp.register_float_function(torch, 'sigmoid')   # register uncommonly used functions
    model, apex_optimizer = amp.initialize(model, optimizers=lookahead, opt_level="O1")

    # Build training model
    parallel_model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

    # Train model
    if args.local_rank == 0:
        tb = SummaryWriter(f'./tensorboard_runs/{args.exp_name}')
    #apex_optimizer.zero_grad()
    #apex_optimizer.step()
    for epoch in range(args.epochs):
        epoch_start_time = time.time()
        train_loss, train_acc = 0., 0.
        valid_loss, valid_acc = 0., 0.
        train_num, valid_num = 0, 0
        train_sampler.set_epoch(epoch)

        # Train
        parallel_model.train()
        # Warm up
        #if epoch < args.warmup_epochs:
        #    scheduler_wu.step()
        for image, target in tqdm(train_loader, total=len(train_loader)):
            apex_optimizer.zero_grad()
            image = image.to(args.local_rank)
            target = target.to(args.local_rank, dtype=torch.long)
            outputs = parallel_model(image)
            predict = torch.argmax(outputs, dim=1)
            batch_loss = loss_function(outputs, target)
            batch_loss /= len(outputs)
            # Apex
            with amp.scale_loss(batch_loss, apex_optimizer) as scaled_loss:
                scaled_loss.backward()
            apex_optimizer.step()

            # Calculate loss & acc
            train_loss += batch_loss.item() * len(image)
            train_acc += (predict == target).sum().item()
            train_num += len(image)

        train_loss = train_loss / train_num
        train_acc = train_acc / train_num
        curr_lr = apex_optimizer.param_groups[0]['lr']
        if args.local_rank == 0:
            tb.add_scalar('LR', curr_lr, epoch)
            tb.add_scalar('Loss/train', train_loss, epoch)
            tb.add_scalar('Acc/train', train_acc, epoch)

        # Valid
        parallel_model.eval()
        with torch.no_grad():
            for image, target in tqdm(valid_loader, total=len(valid_loader)):
                image = image.to(args.local_rank)
                target = target.to(args.local_rank, dtype=torch.long)
                outputs = parallel_model(image)
                predict = torch.argmax(outputs, dim=1)
                batch_loss = loss_function(outputs, target)
                batch_loss /= len(outputs)
                    
                # Calculate loss & acc
                valid_loss += batch_loss.item() * len(image)
                valid_acc += (predict == target).sum().item()
                valid_num += len(image)

        valid_loss = valid_loss / valid_num
        valid_acc = valid_acc / valid_num
        if args.local_rank == 0:
            tb.add_scalar('Loss/valid', valid_loss, epoch)
            tb.add_scalar('Acc/valid', valid_acc, epoch)
            
        # Print result
        print(f'epoch: {epoch:03d}/{args.epochs}, time: {time.time()-epoch_start_time:.2f}s, learning_rate: {curr_lr}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, valid_loss: {valid_loss:.4f}, valid_acc: {valid_acc:.4f}')

        # Learning_rate callbacks
        if epoch <= args.warmup_epochs:
            scheduler_wu.step()
        scheduler_re.step(valid_loss)
        early_stopping(valid_loss)
        if early_stopping.early_stop:
            break

        # Save_checkpoint
        if args.local_rank == 0:
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                torch.save(parallel_model.module.state_dict(), f'{latest_model_path}.pt')

    if args.local_rank == 0:
        tb.close()