Example #1
    def __call__(
        self, net: nn.Module, input_names: List[str], data_loaders
    ) -> None:
        optimizer = torch.optim.Adagrad(
            net.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        writer = SummaryWriter(self.tensorboard_path)

        timer = Timer()

        training_iter = iter(data_loaders["training_data_loader"])
        full_batch_iter = iter(data_loaders["full_batch_loader"])

        avg_epoch_grad = 0.0
        for epoch_no in range(self.epochs):
            if self.decreasing_step_size:
                for param_group in optimizer.param_groups:
                    param_group["lr"] *= 1 / math.sqrt(epoch_no + 1)
            for batch_no in range(self.num_batches_per_epoch):
                with timer("gradient oracle"):
                    data_entry = next(training_iter)
                    optimizer.zero_grad()
                    inputs = [
                        data_entry[k].to(self.device) for k in input_names
                    ]
                    loss = self.inference(net, inputs)
                    loss.backward()
                    optimizer.step()

            # compute the gradient norm and loss over training set
            avg_epoch_loss = 0.0
            full_batch_iter = iter(data_loaders["full_batch_loader"])
            net.zero_grad()
            for i, data_entry in enumerate(full_batch_iter):
                inputs = [data_entry[k].to(self.device) for k in input_names]
                loss = self.inference(net, inputs)
                loss.backward()
                avg_epoch_loss += loss.item()
            avg_epoch_loss /= i + 1
            epoch_grad = 0.0
            for p in net.parameters():
                if p.grad is None:
                    continue
                epoch_grad += torch.norm(p.grad.data / (i + 1)).item()
            net.zero_grad()

            # compute the validation loss
            validation_loss = None
            if self.eval_model and epoch_no % self.validation_freq == 0:
                validation_iter = iter(data_loaders["validation_data_loader"])
                validation_loss = 0.0
                with torch.no_grad():
                    for i, data_entry in enumerate(validation_iter):
                        net.zero_grad()
                        inputs = [
                            data_entry[k].to(self.device) for k in input_names
                        ]
                        loss = self.inference(net, inputs)
                        validation_loss += loss.item()
                    validation_loss /= i + 1
            num_iters = (
                self.num_batches_per_epoch * (epoch_no + 1) * self.batch_size
            )
            avg_epoch_grad = (avg_epoch_grad * epoch_no + epoch_grad) / (
                epoch_no + 1
            )
            time_in_ms = timer.totals["gradient oracle"] * 1000
            writer.add_scalar(
                "gradnorm/iters",
                avg_epoch_grad,
                (epoch_no + 1) * self.num_batches_per_epoch,
            )
            writer.add_scalar("gradnorm/grads", avg_epoch_grad, num_iters)
            writer.add_scalar("gradnorm/time", avg_epoch_grad, time_in_ms)
            writer.add_scalar(
                "train_loss/iters",
                avg_epoch_loss,
                (epoch_no + 1) * self.num_batches_per_epoch,
            )
            writer.add_scalar("train_loss/grads", avg_epoch_loss, num_iters)
            writer.add_scalar("train_loss/time", avg_epoch_loss, time_in_ms)
            if self.eval_model and epoch_no % self.validation_freq == 0:
                writer.add_scalar(
                    "val_loss/iters",
                    validation_loss,
                    (epoch_no + 1) * self.num_batches_per_epoch,
                )
                writer.add_scalar("val_loss/grads", validation_loss, num_iters)
                writer.add_scalar("val_loss/time", validation_loss, time_in_ms)
                print(
                    "\nTraining Loss: {:.4f}, Test Loss: {:.4f}\n".format(
                        avg_epoch_loss, validation_loss
                    )
                )
            else:
                print(f"\nTraining Loss: {avg_epoch_loss:.4f} \n")
            print("Epoch ", epoch_no, " is done!")

        writer.close()
        print(
            f"task: {self.task_name} on Adagrad "
            f"with lr={self.learning_rate} is done!"
        )
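
A minimal, self-contained sketch of the step-size schedule used in Example #1 (values here are illustrative): when decreasing_step_size is set, every epoch each param group's learning rate is multiplied by 1/sqrt(epoch_no + 1). Note that, as written, the factor is applied to the current learning rate, so the decay compounds across epochs rather than following lr / sqrt(epoch_no + 1) directly.

    import math

    import torch

    params = [torch.nn.Parameter(torch.zeros(3))]
    optimizer = torch.optim.Adagrad(params, lr=0.1)

    for epoch_no in range(5):
        # same schedule as above: the factor multiplies the *current* lr
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 1 / math.sqrt(epoch_no + 1)
        print(epoch_no, optimizer.param_groups[0]["lr"])
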
Example #2
    def __call__(self, net: nn.Module, input_names: List[str],
                 data_loaders) -> None:
        optimizer = torch.optim.Adam(
            net.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        writer = SummaryWriter(self.tensorboard_path)

        timer = Timer()

        training_iter = iter(data_loaders["training_data_loader"])
        anchor_iter = iter(data_loaders["anchor_data_loader"])

        group_ratio = (data_loaders["group_ratio"]
                       if self.weighted_batch else None)
        avg_epoch_grad = 0.0
        v_0_norm = 0.0
        v_t_norm = 0.0
        for epoch_no in range(self.epochs):
            if self.decreasing_step_size:
                for param_group in optimizer.param_groups:
                    param_group["lr"] *= 1 / math.sqrt(epoch_no + 1)
            for batch_no in range(self.num_batches_per_epoch):
                iter_n = epoch_no * self.num_batches_per_epoch + batch_no
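                # refresh the anchor snapshot on the first step, every
                # self.freq steps, or once the squared estimator norm falls
                # below gamma times the squared anchor gradient norm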
                if (iter_n == 0 or v_t_norm <= self.gamma * v_0_norm
                        or iter_n % self.freq == 0):
                    anchor_model = copy.deepcopy(net)
                    sg_model = copy.deepcopy(net)
                    anchor_model.zero_grad()
                    with timer("gradient oracle"):
                        data_entry = next(anchor_iter)
                        inputs = [
                            data_entry[k].to(self.device) for k in input_names
                        ]
                        loss = self.inference(
                            anchor_model,
                            inputs,
                            weighted_batch=self.weighted_batch,
                            group_ratio=group_ratio,
                        )
                        loss.backward()
                    # squared norm of the freshly computed anchor gradient
                    v_0_norm = 0.0
                    for p in anchor_model.parameters():
                        if p.grad is None:
                            continue
                        v_0_norm += torch.norm(p.grad.data).item() ** 2

                # mini-batch step: evaluate the gradient at the current
                # iterate and at the anchor snapshot
                v_t_norm = 0.0
                data_entry = next(training_iter)
                optimizer.zero_grad()
                with timer("gradient oracle"):
                    inputs = [
                        data_entry[k].to(self.device) for k in input_names
                    ]
                    inputs_ = copy.deepcopy(inputs)
                    net.zero_grad()
                    sg_model.zero_grad()
                    loss = self.inference(sg_model, inputs)
                    loss.backward()

                loss = self.inference(net, inputs_)
                loss.backward()
                with timer("gradient oracle"):
                    for p1, p2, p3 in zip(
                            net.parameters(),
                            sg_model.parameters(),
                            anchor_model.parameters(),
                    ):
                        if (p1.grad is None or p2.grad is None
                                or p3.grad is None):
                            continue
                        # variance-reduced estimator: current mini-batch grad,
                        # minus the same grad at the anchor snapshot, plus the
                        # stored anchor gradient
                        v_t = torch.zeros_like(p1.grad.data, device=p1.device)
                        v_t.add_(p1.grad.data - p2.grad.data + p3.grad.data)
                        p1.grad.data.zero_().add_(v_t)
                        v_t_norm += torch.norm(v_t).item() ** 2
                    optimizer.step()

            # compute the gradient norm and loss over training set
            avg_epoch_loss = 0.0
            full_batch_iter = iter(data_loaders["full_batch_loader"])
            net.zero_grad()
            for i, data_entry in enumerate(full_batch_iter):
                inputs = [data_entry[k].to(self.device) for k in input_names]
                loss = self.inference(net, inputs)
                loss.backward()
                avg_epoch_loss += loss.item()
            avg_epoch_loss /= i + 1
            epoch_grad = 0.0
            for p in net.parameters():
                if p.grad is None:
                    continue
                epoch_grad += torch.norm(p.grad.data / (i + 1)).item()
            net.zero_grad()

            # compute the validation loss
            if self.eval_model and epoch_no % self.validation_freq == 0:
                net_validate = copy.deepcopy(net)
                validation_iter = iter(data_loaders["validation_data_loader"])
                validation_loss = 0.0
                with torch.no_grad():
                    for i, data_entry in enumerate(validation_iter):
                        net_validate.zero_grad()
                        inputs = [
                            data_entry[k].to(self.device) for k in input_names
                        ]
                        loss = self.inference(net_validate, inputs)
                        validation_loss += loss.item()
                validation_loss /= i + 1

            num_iters = (
                self.num_batches_per_epoch *
                (epoch_no + 1) * 2 * self.batch_size +
                self.num_batches_per_epoch / self.freq * self.num_strata *
                (epoch_no + 1) * self.batch_size)
            avg_epoch_grad = (avg_epoch_grad * epoch_no +
                              epoch_grad) / (epoch_no + 1)
            time_in_ms = timer.totals["gradient oracle"] * 1000
            writer.add_scalar(
                "gradnorm/iters",
                avg_epoch_grad,
                (epoch_no + 1) * self.num_batches_per_epoch,
            )
            writer.add_scalar("gradnorm/grads", avg_epoch_grad, num_iters)
            writer.add_scalar("gradnorm/time", avg_epoch_grad, time_in_ms)
            writer.add_scalar(
                "train_loss/iters",
                avg_epoch_loss,
                (epoch_no + 1) * self.num_batches_per_epoch,
            )
            writer.add_scalar("train_loss/grads", avg_epoch_loss, num_iters)
            writer.add_scalar("train_loss/time", avg_epoch_loss, time_in_ms)
            if self.eval_model and epoch_no % self.validation_freq == 0:
                writer.add_scalar(
                    "val_loss/iters",
                    validation_loss,
                    (epoch_no + 1) * self.num_batches_per_epoch,
                )
                writer.add_scalar("val_loss/grads", validation_loss, num_iters)
                writer.add_scalar("val_loss/time", validation_loss, time_in_ms)
                print("\nTraining Loss: {:.4f}, Test Loss: {:.4f}\n".format(
                    avg_epoch_loss, validation_loss))
            else:
                print("\nTraining Loss: {:.4f} \n".format(avg_epoch_loss))
            print("Epoch ", epoch_no, " is done!")

        writer.close()
        print("task: " + self.task_name + " on SAdam with lr=" +
              str(self.learning_rate) + " is done!")
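
A minimal, self-contained sketch of the variance-reduced step direction built in Example #2 (tensor names here are illustrative, not taken from the code above): each parameter's mini-batch gradient at the current iterate (net) is corrected by subtracting the same mini-batch gradient evaluated at the anchor snapshot (sg_model) and adding the stored anchor gradient (anchor_model); the squared norm of the result is what drives the anchor-refresh test.

    import torch

    # illustrative per-parameter gradients
    grad_current = torch.randn(4)        # mini-batch grad at the current iterate
    grad_anchor_batch = torch.randn(4)   # same mini-batch grad at the anchor snapshot
    grad_anchor_full = torch.randn(4)    # anchor (large-batch) gradient kept between refreshes

    v_t = grad_current - grad_anchor_batch + grad_anchor_full
    v_t_norm = torch.norm(v_t).item() ** 2  # compared against gamma * v_0_norm
    print(v_t, v_t_norm)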