Example #1
    def __init__(self, config, net=None):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.best_val_loss = np.inf
        self.tb_writer = SummaryWriter(log_dir=self.log_dir)
        self.clock = TrainClock()
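
All of these snippets are excerpts: they assume import os, import numpy as np, import torch, import torch.nn as nn, import torch.optim as optim, a SummaryWriter (from tensorboardX or torch.utils.tensorboard), and a TrainClock helper that the listing never shows. A minimal TrainClock sketch, reconstructed from how the examples use it (tick() per batch, tock() per epoch, step/epoch counters, checkpointable state); the real implementation may differ:

class TrainClock(object):
    """Tracks the global step, epoch, and minibatch index during training."""
    def __init__(self):
        self.epoch = 0
        self.minibatch = 0
        self.step = 0

    def tick(self):
        # advance one training iteration
        self.minibatch += 1
        self.step += 1

    def tock(self):
        # advance one epoch
        self.epoch += 1
        self.minibatch = 0

    def make_checkpoint(self):
        return {'epoch': self.epoch, 'minibatch': self.minibatch, 'step': self.step}

    def restore_checkpoint(self, clock_dict):
        self.epoch = clock_dict['epoch']
        self.minibatch = clock_dict['minibatch']
        self.step = clock_dict['step']
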
Example #2
    def __init__(self, config):
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net(config)

        # set loss function
        self.set_loss_function()

        # set optimizer
        self.set_optimizer(config)
Example #3
    def __init__(self, config):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.device = config.device
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net()

        # set loss function
        self.set_loss_function()

        # set optimizer
        self.set_optimizer(config)

        # set tensorboard writer
        self.train_tb = SummaryWriter(os.path.join(self.log_dir, 'train.events'))
        self.val_tb = SummaryWriter(os.path.join(self.log_dir, 'val.events'))
Example #4
    def __init__(self, train_spec, net=None):
        # setproctitle(config.exp_name)

        self.log_dir = train_spec.log_dir
        ensure_dir(self.log_dir)
        # logconf.set_output_file(os.path.join(self.log_dir, 'log.txt'))
        self.model_dir = train_spec.log_model_dir
        ensure_dir(self.model_dir)

        self.net = net
        self.clock = TrainClock()
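
ensure_dir is not shown in the listing; a typical definition, assumed here:

def ensure_dir(path):
    # create the directory (and any missing parents) if it does not exist yet
    if not os.path.exists(path):
        os.makedirs(path)
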
Example #5
    def __init__(self, config):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net(config)
        # print('-----network architecture-----')
        # print(self.net)

        # set loss function
        self.set_loss_function()

        # set optimizer and scheduler
        self.set_optimizer(config)
        self.set_scheduler(config)

        # set tensorboard writer
        self.train_tb = SummaryWriter(os.path.join(self.log_dir, 'train.events'))
        self.val_tb = SummaryWriter(os.path.join(self.log_dir, 'val.events'))
Example #6
class Session:
    def __init__(self, config, net=None):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.best_val_loss = np.inf
        self.tb_writer = SummaryWriter(log_dir=self.log_dir)
        self.clock = TrainClock()

    def save_checkpoint(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        tmp = {
            'state_dict': self.net.state_dict(),
            'best_val_loss': self.best_val_loss,
            'clock': self.clock.make_checkpoint(),
        }
        torch.save(tmp, ckp_path)

    def load_checkpoint(self, ckp_path):
        checkpoint = torch.load(ckp_path)
        self.net.load_state_dict(checkpoint['state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])
        self.best_val_loss = checkpoint['best_val_loss']
Example #7
class Session:
    def __init__(self, config, net=None):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.best_val_acc = 0.0
        self.tb_writer = SummaryWriter(log_dir=self.log_dir)
        self.clock = TrainClock()

    def save_checkpoint(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        tmp = {
            'net': self.net,
            'best_val_acc': self.best_val_acc,
            'clock': self.clock.make_checkpoint(),
        }
        torch.save(tmp, ckp_path)

    def load_checkpoint(self, ckp_path):
        checkpoint = torch.load(ckp_path)
        self.net = checkpoint['net']
        self.clock.restore_checkpoint(checkpoint['clock'])
        self.best_val_acc = checkpoint['best_val_acc']
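
Note that, unlike Example #6, this checkpoint stores the entire module object rather than its state_dict(), so torch.load only works if the original network class is importable at load time. The two patterns side by side:

# pickles the full module; loading requires the original class definition
torch.save({'net': self.net}, ckp_path)

# more portable: persist only the parameters
torch.save({'state_dict': self.net.state_dict()}, ckp_path)
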
Example #8
    parser = argparse.ArgumentParser()
    parser.add_argument('--val', action='store_true', help='whether to run validation')
    parser.add_argument('--bench', action='store_true',
                        help='whether to generate results on the benchmark dataset')
    parser.add_argument('--smooth', action='store_true',
                        help='whether to generate SmoothGrad results on the benchmark dataset')
    parser.add_argument('--gen', action='store_true',
                        help='whether to generate large-eps adversarial examples on the benchmark data')

    parser.add_argument('--resume', type=str, default=None,
                        help='checkpoint path')

    args = parser.parse_args()

    clock = TrainClock()
    clock.epoch = 21
    net = ant_model()
    net.cuda()
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            check_point = torch.load(args.resume)
            net.load_state_dict(check_point['state_dict'])

            print('Model loaded from {} with metrics:'.format(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

        base_path = os.path.split(args.resume)[0]
    else:
        ...  # the else branch is truncated here in the original listing
Example #9
class BaseAgent(object):
    """Base trainer that provides commom training behavior. 
        All trainer should be subclass of this class.
    """
    def __init__(self, config):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.device = config.device
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net()

        # set loss function
        self.set_loss_function()

        # set optimizer
        self.set_optimizer(config)

        # set tensorboard writer
        self.train_tb = SummaryWriter(os.path.join(self.log_dir, 'train.events'))
        self.val_tb = SummaryWriter(os.path.join(self.log_dir, 'val.events'))

    @abstractmethod
    def build_net(self):
        raise NotImplementedError

    def set_loss_function(self):
        """set loss function used in training"""
        self.criterion = nn.MSELoss().to(self.device)

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.lr_decay)

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
        else:
            save_path = os.path.join(self.model_dir, "{}.pth".format(name))
        print("Checkpoint saved at {}".format(save_path))
        if isinstance(self.net, nn.DataParallel):
            torch.save({
                'clock': self.clock.make_checkpoint(),
                'model_state_dict': self.net.module.cpu().state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
            }, save_path)
        else:
            torch.save({
                'clock': self.clock.make_checkpoint(),
                'model_state_dict': self.net.cpu().state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
            }, save_path)
        self.net.cuda()

    def load_ckpt(self, name='latest'):
        """load a previously saved checkpoint"""
        name = name if name == 'latest' else "ckpt_epoch{}".format(name)
        load_path = os.path.join(self.model_dir, "{}.pth".format(name))
        if not os.path.exists(load_path):
            raise ValueError("Checkpoint {} does not exist.".format(load_path))

        checkpoint = torch.load(load_path)
        print("Checkpoint loaded from {}".format(load_path))
        if isinstance(self.net, nn.DataParallel):
            self.net.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    @abstractmethod
    def forward(self, data):
        pass

    def update_network(self, loss_dict):
        """update network by back propagation"""
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """record and update learning rate"""
        self.train_tb.add_scalar('learning_rate', self.optimizer.param_groups[-1]['lr'], self.clock.epoch)
        self.scheduler.step(self.clock.epoch)

    def record_losses(self, loss_dict, mode='train'):
        losses_values = {k: v.item() for k, v in loss_dict.items()}

        # record loss to tensorboard
        tb = self.train_tb if mode == 'train' else self.val_tb
        for k, v in losses_values.items():
            tb.add_scalar(k, v, self.clock.step)

    def train_func(self, data):
        """one step of training"""
        self.net.train()

        outputs, losses = self.forward(data)

        self.update_network(losses)
        self.record_losses(losses, 'train')

        return outputs, losses

    def val_func(self, data):
        """one step of validation"""
        self.net.eval()

        with torch.no_grad():
            outputs, losses = self.forward(data)

        self.record_losses(losses, 'validation')

        return outputs, losses

    def visualize_batch(self, data, mode, **kwargs):
        """write visualization results to tensorboard writer"""
        raise NotImplementedError
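
BaseAgent leaves build_net and forward abstract. A minimal hypothetical subclass and driver loop, assuming a CUDA device (save_ckpt calls self.net.cuda()) and a config carrying log_dir, model_dir, device, batch_size, lr, lr_decay, and nr_epochs; MLPAgent and the loader names are illustrative, not part of the original code:

class MLPAgent(BaseAgent):
    def build_net(self):
        # any nn.Module works; BaseAgent stores the result as self.net
        net = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1))
        return net.to(self.device)

    def forward(self, data):
        # expects (inputs, targets) batches with targets shaped (N, 1)
        inputs, targets = data
        outputs = self.net(inputs.to(self.device))
        loss = self.criterion(outputs, targets.to(self.device))
        # update_network() sums the dict values, so each entry must be a loss tensor
        return outputs, {'mse': loss}

agent = MLPAgent(config)
for e in range(config.nr_epochs):
    for data in train_loader:
        outputs, losses = agent.train_func(data)
        agent.clock.tick()
    agent.update_learning_rate()
    agent.val_func(next(iter(val_loader)))
    agent.save_ckpt()
    agent.clock.tock()
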
Example #10
class VGGAgent(object):
    """Base trainer that provides commom training behavior.
        All trainer should be subclass of this class.
    """
    def __init__(self, config):
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net(config)

        # set loss function
        self.set_loss_function()

        # set optimizer
        self.set_optimizer(config)

    def build_net(self, config):
        return get_network("VGG", config).cuda()

    def set_loss_function(self):
        """set loss function used in training"""
        self.criterion = nn.CrossEntropyLoss().cuda()

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, config.lr_decay)

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(
                self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
        else:
            save_path = os.path.join(self.model_dir, "{}.pth".format(name))
        print("Checkpoint saved at {}".format(save_path))
        if isinstance(self.net, nn.DataParallel):
            torch.save(
                {
                    'clock': self.clock.make_checkpoint(),
                    'model_state_dict': self.net.module.cpu().state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                }, save_path)
        else:
            torch.save(
                {
                    'clock': self.clock.make_checkpoint(),
                    'model_state_dict': self.net.cpu().state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                }, save_path)
        self.net.cuda()

    def load_ckpt(self, name='latest'):
        """load a previously saved checkpoint"""
        name = name if name == 'latest' else "ckpt_epoch{}".format(name)
        load_path = os.path.join(self.model_dir, "{}.pth".format(name))
        if not os.path.exists(load_path):
            raise ValueError("Checkpoint {} does not exist.".format(load_path))

        checkpoint = torch.load(load_path)
        print("Checkpoint loaded from {}".format(load_path))
        if isinstance(self.net, nn.DataParallel):
            self.net.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    def forward(self, data, label):
        output = self.net(data)
        losses = self.criterion(output, label)

        return output, {"loss": losses}

    def update_network(self, loss_dict):
        """update network by back propagation"""
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """record and update learning rate"""
        self.scheduler.step(self.clock.epoch)

    def train_func(self, data, label):
        """one step of training"""
        self.net.train()

        outputs, losses = self.forward(data, label)

        self.update_network(losses)

        return outputs, losses

    def val_func(self, data, label):
        """one step of validation"""
        self.net.eval()

        with torch.no_grad():
            outputs, losses = self.forward(data, label)

        return outputs, losses
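
load_ckpt resolves its argument against the ckpt_epoch{N}.pth naming used by save_ckpt, passing 'latest' through verbatim. A hypothetical evaluation-only use, assuming a classification val_loader (a placeholder name):

agent = VGGAgent(config)
agent.load_ckpt(20)  # loads model_dir/ckpt_epoch20.pth; load_ckpt('latest') would load latest.pth
for data, label in val_loader:
    outputs, losses = agent.val_func(data.cuda(), label.cuda())
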
Example #11
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--continue',
                        dest='continue_path',
                        type=str,
                        required=False)
    parser.add_argument('-g', '--gpu_ids', type=int, default=0, required=False)
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)
    config.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    config.isTrain = True

    if not os.path.exists('train_log'):
        os.symlink(config.exp_dir, 'train_log')

    # get dataset
    train_loader = get_dataloader("train", batch_size=config.batch_size)
    val_loader = get_dataloader("test", batch_size=config.batch_size)
    val_cycle = cycle(val_loader)
    dataset_size = len(train_loader)
    print('The number of training motions = %d' %
          (dataset_size * config.batch_size))

    # create tensorboard writer
    train_tb = SummaryWriter(os.path.join(config.log_dir, 'train.events'))
    val_tb = SummaryWriter(os.path.join(config.log_dir, 'val.events'))

    # get model
    net = CycleGANModel(config)
    net.print_networks(True)

    # start training
    clock = TrainClock()
    net.train()

    for e in range(config.nr_epochs):
        # begin iteration
        pbar = tqdm(train_loader)
        for b, data in enumerate(pbar):
            net.train()
            net.set_input(data)        # unpack data from the dataset and apply preprocessing
            net.optimize_parameters()  # compute losses, backpropagate, and update network weights

            # get loss
            losses_values = net.get_current_losses()

            # update tensorboard
            train_tb.add_scalars('train_loss',
                                 losses_values,
                                 global_step=clock.step)

            # visualize
            if clock.step % config.visualize_frequency == 0:
                motion_dict = net.infer()
                for k, v in motion_dict.items():
                    phase = 'h' if k[-1] == 'A' else 'nh'
                    motion3d = train_loader.dataset.preprocess_inv(
                        v.detach().cpu().numpy()[0], phase)
                    img = plot_motion(motion3d, phase)
                    train_tb.add_image(k, img, global_step=clock.step)

            pbar.set_description("EPOCH[{}][{}/{}]".format(
                e, b, len(train_loader)))
            pbar.set_postfix(OrderedDict(losses_values))

            # validation
            if clock.step % config.val_frequency == 0:
                net.eval()
                data = next(val_cycle)
                net.set_input(data)
                net.forward()

                losses_values = net.get_current_losses()
                val_tb.add_scalars('val_loss',
                                   losses_values,
                                   global_step=clock.step)

                # visualize
                if clock.step % config.visualize_frequency == 0:
                    motion_dict = net.infer()
                    for k, v in motion_dict.items():
                        phase = 'h' if k[-1] == 'A' else 'nh'
                        motion3d = val_loader.dataset.preprocess_inv(
                            v.detach().cpu().numpy()[0], phase)
                        img = plot_motion(motion3d, phase)
                        val_tb.add_image(k, img, global_step=clock.step)

            clock.tick()

        # log the learning rate to tensorboard
        lr = net.optimizers[0].param_groups[0]['lr']
        train_tb.add_scalar("learning_rate", lr, global_step=clock.step)

        if clock.epoch % config.save_frequency == 0:
            net.save_networks(epoch=e)

        clock.tock()
        net.update_learning_rate()  # update learning rates at the end of every epoch
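
cycle here is presumably a restarting generator rather than itertools.cycle, which would keep every batch it has yielded in memory; a common definition, assumed:

def cycle(iterable):
    # restart the iterable (e.g. a DataLoader) whenever it is exhausted,
    # instead of caching all items the way itertools.cycle does
    while True:
        for item in iterable:
            yield item
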
Example #12
DEVICE = torch.device('cuda:{}'.format(args.d))
if args.exp is None:
    cur_dir = os.path.realpath('./')
    args.exp = cur_dir.split(os.path.sep)[-1]
log_dir = os.path.join('../../logs', args.exp)
exp_dir = os.path.join('../../exps', args.exp)
train_res_path = os.path.join(exp_dir, 'train_results.txt')
val_res_path = os.path.join(exp_dir, 'val_results.txt')
final_res_path = os.path.join(exp_dir, 'final_results.txt')
if not os.path.exists(exp_dir):
    os.mkdir(exp_dir)

save_args(args, exp_dir)
writer = SummaryWriter(log_dir)

clock = TrainClock()

learning_rate_policy = [[5, 0.01], [3, 0.001], [2, 0.0001]]
get_learning_rate = MultiStageLearningRatePolicy(learning_rate_policy)


def adjust_learning_rate(optimizer, epoch):
    lr = get_learning_rate(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


torch.backends.cudnn.benchmark = True
ds_train = create_train_dataset(args.batch_size)
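
MultiStageLearningRatePolicy is not defined in the snippet. From the [num_epochs, lr] pairs above (5 epochs at 0.01, then 3 at 0.001, then 2 at 0.0001), a plausible reconstruction, offered only as an assumption:

class MultiStageLearningRatePolicy(object):
    """Maps an epoch index to a learning rate given [num_epochs, lr] stages."""
    def __init__(self, stages):
        self.stages = stages

    def __call__(self, epoch):
        boundary = 0
        for num_epochs, lr in self.stages:
            boundary += num_epochs
            if epoch < boundary:
                return lr
        # past the final stage, keep the last learning rate
        return self.stages[-1][1]
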
Example #13
class BaseAgent(object):
    """Base trainer that provides commom training behavior. 
        All trainer should be subclass of this class.
    """
    def __init__(self, config):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.device = config.device
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net(config).cuda()

        # set loss function
        self.set_loss_function()

        # set optimizer
        self.set_optimizer(config)

    @abstractmethod
    def build_net(self, config):
        raise NotImplementedError

    def set_loss_function(self):
        """set loss function used in training"""
        self.criterion = nn.MSELoss().cuda()

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   config.lr_step_size)

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(
                self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
        else:
            save_path = os.path.join(self.model_dir, name)
        if isinstance(self.net, nn.DataParallel):
            torch.save(
                {
                    'clock': self.clock.make_checkpoint(),
                    'model_state_dict': self.net.module.cpu().state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                }, save_path)
        else:
            torch.save(
                {
                    'clock': self.clock.make_checkpoint(),
                    'model_state_dict': self.net.cpu().state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                }, save_path)
        self.net.cuda()

    def load_ckpt(self, path=None):
        """load checkpoint from saved checkpoint"""
        if path is not None:
            load_path = path
        else:
            load_path = os.path.join(self.model_dir, "latest.pth.tar")
        checkpoint = torch.load(load_path)
        if isinstance(self.net, nn.DataParallel):
            self.net.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    @abstractmethod
    def forward(self, data):
        pass

    def update_network(self, loss_dict):
        """update network by back propagation"""
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        self.scheduler.step(self.clock.epoch)

    def train_func(self, data):
        """one step of training"""
        self.net.train()

        outputs, losses = self.forward(data)

        self.update_network(losses)

        return outputs, losses

    def val_func(self, data):
        """one step of validation"""
        self.net.eval()

        with torch.no_grad():
            outputs, losses = self.forward(data)

        return outputs, losses

    def visualize_batch(self, data, tb, **kwargs):
        """write visualization results to tensorboard writer"""
        raise NotImplementedError