Example #1
import torch.nn as nn

from torchgpipe import GPipe
# Partition is an internal torchgpipe class; its exact import path depends on the
# torchgpipe version under test.


def test_partitions():
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)
    model = GPipe(model, [1, 1], devices=['cpu', 'cpu'])

    assert isinstance(model.partitions, nn.ModuleList)
    assert isinstance(model.partitions[0], Partition)
    assert isinstance(model.partitions[1], Partition)

    assert 'partitions.0.module.0.weight' in model.state_dict()
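A minimal sketch of how the resulting GPipe module could be exercised (assuming torchgpipe's public API; the tensor shapes here are only illustrative):

import torch
import torch.nn as nn
from torchgpipe import GPipe

model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
model = GPipe(model, balance=[1, 1], devices=['cpu', 'cpu'], chunks=2)

x = torch.randn(4, 1)     # the batch is split into 2 micro-batches internally
output = model(x)         # pipelined forward pass across both partitions
output.sum().backward()   # gradients flow back through the pipeline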
Example #2
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 200 == 199:  # print every 200 mini-batches
            print('[epoch-%d, %5d samples] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
print("Finished Training")
endtime1 = time.time()
print("training time: %d" % (endtime1 - starttime))

torch.save(net.state_dict(), PATH)

# Test the network on the test data
dataiter = iter(testloader)
images, labels = next(dataiter)  # .next() was removed from DataLoader iterators
images = images.to(device0)

# print images
# imshow(torchvision.utils.make_grid(images))
print('GroundTruth:', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

net.load_state_dict(torch.load(PATH))

outputs = net(images)
# the class with the highest score (energy) is the predicted class
_, predicted = torch.max(outputs, 1)  # dim=1: take the max over the class dimension
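A natural follow-up, not shown in the original snippet, is measuring accuracy over the whole test set; a minimal sketch assuming the same net, testloader, and device0:

correct = 0
total = 0
with torch.no_grad():  # no gradients needed during evaluation
    for images, labels in testloader:
        outputs = net(images.to(device0))
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted.cpu() == labels).sum().item()

print('Accuracy on the test set: %.2f %%' % (100 * correct / total))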
Example #3
import os
from collections import deque
from typing import Mapping, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchgpipe import GPipe
from tqdm import tqdm

# gpipe_encoder, gpipe_decoder, DenseSparseAdam, get_p_5, get_n_5 and logger are
# project-specific helpers assumed to be importable from the surrounding code base.


class GPipeModel(object):
    def __init__(self,
                 model_name,
                 model_path,
                 gradient_clip_value=5.0,
                 device_ids=None,
                 **kwargs):
        gpipe_model = nn.Sequential(gpipe_encoder(model_name, **kwargs),
                                    gpipe_decoder(model_name, **kwargs))
        self.model = GPipe(gpipe_model, balance=[1, 1], chunks=2)
        self.in_device = self.model.devices[0]
        self.out_device = self.model.devices[-1]
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.model_path, self.state = model_path, {}
        os.makedirs(os.path.split(self.model_path)[0], exist_ok=True)
        self.gradient_clip_value = gradient_clip_value
        self.gradient_norm_queue = deque([np.inf], maxlen=5)
        self.optimizer = None

    def train_step(self, train_x: torch.Tensor, train_y: torch.Tensor):
        self.optimizer.zero_grad()
        self.model.train()
        scores = self.model(train_x)
        loss = self.loss_fn(scores, train_y)
        loss.backward()
        self.clip_gradient()
        self.optimizer.step(closure=None)
        return loss.item()

    def predict_step(self, data_x: torch.Tensor, k: int):
        self.model.eval()
        with torch.no_grad():
            scores, labels = torch.topk(self.model(data_x), k)
            return torch.sigmoid(scores).cpu(), labels.cpu()

    def get_optimizer(self, **kwargs):
        self.optimizer = DenseSparseAdam(self.model.parameters(), **kwargs)

    def train(self,
              train_loader: DataLoader,
              valid_loader: DataLoader,
              opt_params: Optional[Mapping] = None,
              nb_epoch=100,
              step=100,
              k=5,
              early=100,
              verbose=True,
              swa_warmup=None,
              **kwargs):
        self.get_optimizer(**({} if opt_params is None else opt_params))
        global_step, best_n5, e = 0, 0.0, 0
        print_loss = 0.0
        for epoch_idx in range(nb_epoch):
            if epoch_idx == swa_warmup:
                self.swa_init()
            for i, (train_x, train_y) in enumerate(train_loader, 1):
                global_step += 1
                loss = self.train_step(
                    train_x.to(self.in_device, non_blocking=True),
                    train_y.to(self.out_device, non_blocking=True))
                print_loss += loss
                if global_step % step == 0:
                    self.swa_step()
                    self.swap_swa_params()
                    # evaluate on the validation set
                    labels = []
                    valid_loss = 0.0
                    self.model.eval()
                    with torch.no_grad():
                        for (valid_x, valid_y) in valid_loader:
                            logits = self.model(
                                valid_x.to(self.in_device, non_blocking=True))
                            valid_loss += self.loss_fn(
                                logits,
                                valid_y.to(self.out_device,
                                           non_blocking=True)).item()
                            scores, tmp = torch.topk(logits, k)
                            labels.append(tmp.cpu())
                    valid_loss /= len(valid_loader)
                    labels = np.concatenate(labels)
                    targets = valid_loader.dataset.data_y
                    p5, n5 = get_p_5(labels, targets), get_n_5(labels, targets)
                    if n5 > best_n5:
                        # guard against swa_warmup=None, which would make the comparison fail
                        self.save_model(swa_warmup is None or epoch_idx > 3 * swa_warmup)
                        best_n5, e = n5, 0
                    else:
                        e += 1
                        if early is not None and e > early:
                            return
                    self.swap_swa_params()
                    if verbose:
                        logger.info(
                            '%d %d train loss: %.7f valid loss: %.7f P@5: %.5f N@5: %.5f early stop: %d' %
                            (epoch_idx, i * train_loader.batch_size, print_loss / step,
                             valid_loss, round(p5, 5), round(n5, 5), e))
                        print_loss = 0.0

    def predict(self,
                data_loader: DataLoader,
                k=100,
                desc='Predict',
                **kwargs):
        self.load_model()
        scores_list, labels_list = zip(*(
            self.predict_step(data_x.to(self.in_device, non_blocking=True), k)
            for data_x in tqdm(data_loader, desc=desc, leave=False)))
        return np.concatenate(scores_list), np.concatenate(labels_list)

    def save_model(self, last_epoch):
        if not last_epoch:
            return
        for trial in range(5):
            try:
                torch.save(self.model.state_dict(), self.model_path)
                break
            except Exception:
                print('saving failed')

    def load_model(self):
        self.model.load_state_dict(torch.load(self.model_path))

    def clip_gradient(self):
        if self.gradient_clip_value is not None:
            max_norm = max(self.gradient_norm_queue)
            # clip_grad_norm_ returns the pre-clipping total norm (a tensor in recent
            # PyTorch versions, hence the float() conversion)
            total_norm = float(torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm * self.gradient_clip_value))
            self.gradient_norm_queue.append(
                min(total_norm, max_norm * 2.0, 1.0))
            if total_norm > max_norm * self.gradient_clip_value:
                logger.warning(
                    f'Clipping gradients with total norm {round(total_norm, 5)} '
                    f'and max norm {round(max_norm, 5)}')

    def swa_init(self):
        if 'swa' not in self.state:
            logger.info('SWA Initializing')
            swa_state = self.state['swa'] = {'models_num': 1}
            for n, p in self.model.named_parameters():
                swa_state[n] = p.data.cpu().detach()

    def swa_step(self):
        if 'swa' in self.state:
            swa_state = self.state['swa']
            swa_state['models_num'] += 1
            beta = 1.0 / swa_state['models_num']
            with torch.no_grad():
                # running average of the weights, kept on CPU
                for n, p in self.model.named_parameters():
                    swa_state[n].mul_(1.0 - beta).add_(p.data.cpu(), alpha=beta)

    def swap_swa_params(self):
        if 'swa' in self.state:
            swa_state = self.state['swa']
            for n, p in self.model.named_parameters():
                gpu_id = p.get_device()
                # swap the live parameters with the CPU-stored SWA averages
                p.data, swa_state[n] = swa_state[n], p.data.cpu()
                p.data = p.data.cuda(gpu_id)

    def disable_swa(self):
        if 'swa' in self.state:
            del self.state['swa']
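
A minimal usage sketch; the constructor keyword arguments, loader names and hyper-parameters below are placeholders, not part of the original example:

model = GPipeModel('my_model', 'models/gpipe_model.pt', hidden_size=256)
model.train(train_loader, valid_loader, opt_params={'lr': 1e-3},
            nb_epoch=30, step=100, swa_warmup=4)
scores, labels = model.predict(test_loader, k=100)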