def set_model(args, cfg, checkpoint):
    # model
    if checkpoint:
        model = Classifier(pretrained=False)
        model.load_state_dict(checkpoint['model'])
    else:
        model = Classifier(pretrained=True)
    if args.data_parallel:
        model = DataParallel(model)
    model = model.to(device=args.device)

    # optimizer
    if cfg['optimizer'] == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg['learning_rate'],
                              weight_decay=cfg['weight_decay'])
    elif cfg['optimizer'] == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=cfg['learning_rate'],
                               weight_decay=cfg['weight_decay'])
    elif cfg['optimizer'] == 'adabound':
        optimizer = AdaBound(model.parameters(),
                             lr=cfg['learning_rate'],
                             final_lr=0.1,
                             weight_decay=cfg['weight_decay'])
    elif cfg['optimizer'] == 'amsbound':
        optimizer = AdaBound(model.parameters(),
                             lr=cfg['learning_rate'],
                             final_lr=0.1,
                             weight_decay=cfg['weight_decay'],
                             amsbound=True)
    else:
        raise ValueError('unknown optimizer: {}'.format(cfg['optimizer']))

    # checkpoint
    if checkpoint and args.load_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer
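# A minimal call-site sketch for set_model (hypothetical values; Classifier,
# DataParallel, optim, and AdaBound are imported elsewhere in this script):
#
#     cfg = {'optimizer': 'adabound', 'learning_rate': 1e-3, 'weight_decay': 5e-4}
#     checkpoint = torch.load('last.pth') if args.resume else None
#     model, optimizer = set_model(args, cfg, checkpoint)
#
# When resuming, the checkpoint is expected to be a dict holding 'model' and
# 'optimizer' state_dicts, matching the keys read back above.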
assert opt.model == checkpoint['model']
if opt.pretrained:
    # Load all pretrained weights except the final conv layers, which stay at
    # their fresh initialization for the new output classes.
    model_dict = model.state_dict()
    passed_dict = ['conv9.weight', 'conv10.weight', 'conv11.weight']
    new_state_dict = {k: v for k, v in checkpoint['state_dict'].items()
                      if k not in passed_dict}
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)
else:
    model.load_state_dict(checkpoint['state_dict'])
    opt.begin_epoch = checkpoint['epoch']

model = model.to(opt.device)

if not opt.no_train and not opt.pretrained:
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_mAP = checkpoint["best_mAP"]

########################################
#           Train, Val, Test           #
########################################
if opt.test:
    test(model, test_dataloader, opt.begin_epoch, opt)
else:
    for epoch in range(opt.begin_epoch, opt.num_epochs + 1):
        if not opt.no_train:
            print("\n---- Training Model ----")
            train(model, optimizer, train_dataloader, epoch, opt,
                  train_logger, best_mAP=best_mAP)
        if not opt.no_val and (epoch + 1) % opt.val_interval == 0:
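# The partial load above skips the final conv weights so they can be retrained
# for a new head. A generic sketch of that filtered-load pattern (the helper
# name is illustrative, not part of this repo):
def load_partial_state_dict(model, state_dict, skip_keys):
    """Copy every matching weight, leaving keys in skip_keys at init values."""
    model_dict = model.state_dict()
    model_dict.update({k: v for k, v in state_dict.items()
                       if k not in skip_keys})
    model.load_state_dict(model_dict)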
class TrainNetwork(object):
    """The main train network"""
    def __init__(self, args):
        super(TrainNetwork, self).__init__()
        self.args = args
        self.dur_time = 0
        self.logger = self._init_log()

        if not torch.cuda.is_available():
            self.logger.info('no gpu device available')
            sys.exit(1)

        self._init_hyperparam()
        self._init_random_and_device()
        self._init_model()

    def _init_hyperparam(self):
        if 'cifar100' == self.args.train_dataset:
            # cifar10: 6000 images per class, 10 classes,
            #   50000 training images and 10000 test images
            # cifar100: 600 images per class, 100 classes,
            #   500 training and 100 testing images per class
            self.args.num_classes = 100
            self.args.layers = 20
            self.args.data = '/train_tiny_data/train_data/cifar100'
        elif 'imagenet' == self.args.train_dataset:
            self.args.data = '/train_data/imagenet'
            self.args.num_classes = 1000
            self.args.weight_decay = 3e-5
            self.args.report_freq = 100
            self.args.init_channels = 50
            self.args.drop_path_prob = 0
        elif 'tiny-imagenet' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/tiny-imagenet'
            self.args.num_classes = 200
        elif 'food101' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/food-101'
            self.args.num_classes = 101
            self.args.init_channels = 48

    def _init_log(self):
        self.args.save = ('../logs/eval/' + self.args.arch + '/'
                          + self.args.train_dataset
                          + '/eval-{}-{}'.format(self.args.save,
                                                 time.strftime('%Y%m%d-%H%M')))
        dutils.create_exp_dir(self.args.save, scripts_to_save=None)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logger = logging.getLogger('Architecture Training')
        logger.addHandler(fh)
        return logger

    def _init_random_and_device(self):
        # Set random seed and cuda device
        np.random.seed(self.args.seed)
        cudnn.benchmark = True
        torch.manual_seed(self.args.seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(self.args.seed)
        max_free_gpu_id, gpus_info = dutils.get_gpus_memory_info()
        self.device_id = max_free_gpu_id
        self.gpus_info = gpus_info
        self.device = torch.device('cuda:{}'.format(
            0 if self.args.multi_gpus else self.device_id))

    def _init_model(self):
        self.train_queue, self.valid_queue = self._load_dataset_queue()

        def _init_scheduler():
            if 'cifar' in self.args.train_dataset:
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    self.optimizer, float(self.args.epochs))
            else:
                scheduler = torch.optim.lr_scheduler.StepLR(
                    self.optimizer, self.args.decay_period,
                    gamma=self.args.gamma)
            return scheduler

        genotype = eval('geno_types.%s' % self.args.arch)
        reduce_level = (0 if 'cifar10' in self.args.train_dataset else 0)
        model = EvalNetwork(self.args.init_channels, self.args.num_classes, 0,
                            self.args.layers, self.args.auxiliary,
                            genotype, reduce_level)

        # Try to move the model to multiple gpus
        if torch.cuda.device_count() > 1 and self.args.multi_gpus:
            self.logger.info('use: %d gpus', torch.cuda.device_count())
            model = nn.DataParallel(model)
        else:
            self.logger.info('gpu device = %d' % self.device_id)
            torch.cuda.set_device(self.device_id)
        self.model = model.to(self.device)
        self.logger.info('param size = %fM',
                         dutils.calc_parameters_count(model))

        criterion = nn.CrossEntropyLoss()
        if self.args.num_classes >= 50:
            criterion = CrossEntropyLabelSmooth(self.args.num_classes,
                                                self.args.label_smooth)
        self.criterion = criterion.to(self.device)

        if self.args.opt == 'adam':
            self.optimizer = torch.optim.Adam(
                model.parameters(),
                self.args.learning_rate,
                weight_decay=self.args.weight_decay
            )
        elif self.args.opt == 'adabound':
            self.optimizer = AdaBound(model.parameters(),
                                      self.args.learning_rate,
                                      weight_decay=self.args.weight_decay)
        else:
            self.optimizer = torch.optim.SGD(
                model.parameters(),
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay
            )

        self.best_acc_top1 = 0
        # optionally resume from a checkpoint
        if self.args.resume:
            if os.path.isfile(self.args.resume):
                print("=> loading checkpoint {}".format(self.args.resume))
                checkpoint = torch.load(self.args.resume)
                self.dur_time = checkpoint['dur_time']
                self.args.start_epoch = checkpoint['epoch']
                self.best_acc_top1 = checkpoint['best_acc_top1']
                self.args.drop_path_prob = checkpoint['drop_path_prob']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    self.args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(self.args.resume))

        self.scheduler = _init_scheduler()
        # reload the scheduler if possible
        if self.args.resume and os.path.isfile(self.args.resume):
            checkpoint = torch.load(self.args.resume)
            self.scheduler.load_state_dict(checkpoint['scheduler'])

    def _load_dataset_queue(self):
        if 'cifar' in self.args.train_dataset:
            train_transform, valid_transform = dutils.data_transforms_cifar(self.args)
            if 'cifar10' == self.args.train_dataset:
                train_data = dset.CIFAR10(root=self.args.data, train=True,
                                          download=True, transform=train_transform)
                valid_data = dset.CIFAR10(root=self.args.data, train=False,
                                          download=True, transform=valid_transform)
            else:
                train_data = dset.CIFAR100(root=self.args.data, train=True,
                                           download=True, transform=train_transform)
                valid_data = dset.CIFAR100(root=self.args.data, train=False,
                                           download=True, transform=valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4)
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4)
        elif 'tiny-imagenet' == self.args.train_dataset:
            train_transform, valid_transform = dutils.data_transforms_tiny_imagenet()
            train_data = dartsdset.TinyImageNet200(self.args.data, train=True,
                                                   download=True,
                                                   transform=train_transform)
            valid_data = dartsdset.TinyImageNet200(self.args.data, train=False,
                                                   download=True,
                                                   transform=valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4)
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4)
        elif 'imagenet' == self.args.train_dataset:
            traindir = os.path.join(self.args.data, 'train')
            validdir = os.path.join(self.args.data, 'val')
            train_transform, valid_transform = dutils.data_transforms_imagenet()
            train_data = dset.ImageFolder(traindir, train_transform)
            valid_data = dset.ImageFolder(validdir, valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4)
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size,
                shuffle=False, pin_memory=True, num_workers=4)
        elif 'food101' == self.args.train_dataset:
            traindir = os.path.join(self.args.data, 'train')
            validdir = os.path.join(self.args.data, 'val')
            train_transform, valid_transform = dutils.data_transforms_food101()
            train_data = dset.ImageFolder(traindir, train_transform)
            valid_data = dset.ImageFolder(validdir, valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size,
                shuffle=True, pin_memory=True, num_workers=4)
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size,
                shuffle=False, pin_memory=True, num_workers=4)
        return train_queue, valid_queue

    def run(self):
        self.logger.info('args = %s', self.args)
        run_start = time.time()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            self.scheduler.step()
            self.logger.info('epoch %d / %d lr %e', epoch, self.args.epochs,
                             self.scheduler.get_lr()[0])

            if self.args.no_dropout:
                self.model._drop_path_prob = 0
            else:
                self.model._drop_path_prob = (self.args.drop_path_prob
                                              * epoch / self.args.epochs)
            self.logger.info('drop_path_prob %e', self.model._drop_path_prob)

            train_acc, train_obj = self.train()
            self.logger.info('train loss %e, train acc %f', train_obj, train_acc)

            valid_acc_top1, valid_acc_top5, valid_obj = self.infer()
            self.logger.info('valid loss %e, top1 valid acc %f top5 valid acc %f',
                             valid_obj, valid_acc_top1, valid_acc_top5)
            self.logger.info('best valid acc %f', self.best_acc_top1)

            is_best = False
            if valid_acc_top1 > self.best_acc_top1:
                self.best_acc_top1 = valid_acc_top1
                is_best = True
            dutils.save_checkpoint({
                'epoch': epoch + 1,
                'dur_time': self.dur_time + time.time() - run_start,
                'state_dict': self.model.state_dict(),
                'drop_path_prob': self.args.drop_path_prob,
                'best_acc_top1': self.best_acc_top1,
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }, is_best, self.args.save)
        self.logger.info('train epochs %d, best_acc_top1 %f, dur_time %s',
                         self.args.epochs, self.best_acc_top1,
                         dutils.calc_time(self.dur_time + time.time() - run_start))

    def train(self):
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()
        self.model.train()

        for step, (input, target) in enumerate(self.train_queue):
            input = input.cuda(self.device, non_blocking=True)
            target = target.cuda(self.device, non_blocking=True)

            self.optimizer.zero_grad()
            logits, logits_aux = self.model(input)
            loss = self.criterion(logits, target)
            if self.args.auxiliary:
                loss_aux = self.criterion(logits_aux, target)
                loss += self.args.auxiliary_weight * loss_aux
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()

            prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0:
                self.logger.info('train %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)
        return top1.avg, objs.avg

    def infer(self):
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()
        self.model.eval()

        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.cuda(self.device, non_blocking=True)
                target = target.cuda(self.device, non_blocking=True)

                logits, _ = self.model(input)
                loss = self.criterion(logits, target)

                prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))
                n = input.size(0)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.args.report_freq == 0:
                    self.logger.info('valid %03d %e %f %f', step, objs.avg,
                                     top1.avg, top5.avg)
        return top1.avg, top5.avg, objs.avg
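# A minimal driver for TrainNetwork (sketch; the argument parser is assumed to
# live elsewhere in this repo):
if __name__ == '__main__':
    args = parse_args()  # assumption: repo-provided argparse wrapper
    TrainNetwork(args).run()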
def main():
    args = parse_args()
    update_config(cfg_hrnet, args)

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    #print('networks.' + cfg_hrnet.MODEL.NAME + '.get_pose_net')
    model = eval('models.' + cfg_hrnet.MODEL.NAME + '.get_pose_net')(
        cfg_hrnet, is_train=True)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # show net
    args.channels = 3
    args.height = cfg.data_shape[0]
    args.width = cfg.data_shape[1]
    #net_vision(model, args)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    #torch.optim.Adam
    optimizer = AdaBound(model.parameters(),
                         lr=cfg.lr,
                         weight_decay=cfg.weight_decay)

    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            pretrained_dict = checkpoint['state_dict']
            model.load_state_dict(pretrained_dict)
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

    cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # rough size estimate: 4 bytes per float32 parameter
    print('    Total params: %.2fMB'
          % (sum(p.numel() for p in model.parameters()) / (1024 * 1024) * 4))

    train_loader = torch.utils.data.DataLoader(
        #MscocoMulti(cfg),
        KPloader(cfg),
        batch_size=cfg.batch_size * len(args.gpus))
        #, shuffle=True,
        #num_workers=args.workers, pin_memory=True)

    #for i, (img, targets, valid) in enumerate(train_loader):
    #    print(i, img, targets, valid)

    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch, cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer)
        print('train_loss: ', train_loss)

        # append logger file
        logger.append([epoch + 1, lr, train_loss])

        save_model({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint=args.checkpoint)

    logger.close()
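# adjust_learning_rate is defined elsewhere in this repo; a plausible sketch,
# assuming step decay at the milestone epochs in cfg.lr_dec_epoch with factor
# cfg.lr_gamma (the exact schedule may differ):
def adjust_learning_rate(optimizer, epoch, dec_epochs, gamma):
    """Multiply the base lr by gamma once per milestone epoch already passed."""
    lr = cfg.lr * (gamma ** sum(1 for e in dec_epochs if epoch >= e))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr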
def train_model_v2_1(net, trainloader, validloader, epochs, lr,
                     grad_accum_steps=1, warmup_epoch=1, patience=5,
                     factor=0.5, opt='AdaBound', weight_decay=0.0,
                     loss_w=None, reference_labels=None, cb_beta=0.99,
                     start_epoch=0, opt_state_dict=None):
    """
    mixup, ReduceLROnPlateau, class balance
    """
    net = net.cuda()

    # loss
    loss_w = loss_w if loss_w is not None else [0.5, 0.25, 0.25]
    if reference_labels is None:
        if len(loss_w) == 3:
            criterion = multiloss_wrapper_v1_mixup(loss_funcs=[
                mixup.CrossEntropyLossForMixup(num_class=168),
                mixup.CrossEntropyLossForMixup(num_class=11),
                mixup.CrossEntropyLossForMixup(num_class=7)
            ], weights=loss_w)
        elif len(loss_w) == 4:
            criterion = multiloss_wrapper_v1_mixup(loss_funcs=[
                mixup.CrossEntropyLossForMixup(num_class=168),
                mixup.CrossEntropyLossForMixup(num_class=11),
                mixup.CrossEntropyLossForMixup(num_class=7),
                mixup.CrossEntropyLossForMixup(num_class=1292)
            ], weights=loss_w)
    else:
        if len(loss_w) == 3:
            criterion = multiloss_wrapper_v1_mixup(loss_funcs=[
                cbl.CB_CrossEntropyLoss(reference_labels[:, 0], num_class=168,
                                        beta=cb_beta, label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 1], num_class=11,
                                        beta=cb_beta, label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 2], num_class=7,
                                        beta=cb_beta, label_smooth=0.0)
            ], weights=loss_w)
        elif len(loss_w) == 4:
            criterion = multiloss_wrapper_v1_mixup(loss_funcs=[
                cbl.CB_CrossEntropyLoss(reference_labels[:, 0], num_class=168,
                                        beta=cb_beta, label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 1], num_class=11,
                                        beta=cb_beta, label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 2], num_class=7,
                                        beta=cb_beta, label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 3], num_class=1292,
                                        beta=cb_beta, label_smooth=0.0)
            ], weights=loss_w)
    # note: validation always scores the three main heads, even when a fourth
    # training head is used
    test_criterion = multiloss_wrapper_v1(loss_funcs=[
        nn.CrossEntropyLoss(),
        nn.CrossEntropyLoss(),
        nn.CrossEntropyLoss()
    ], weights=loss_w)

    # opt
    if opt == 'SGD':
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    elif opt == 'AdaBound':
        optimizer = AdaBound(net.parameters(), lr=lr, final_lr=0.1,
                             weight_decay=weight_decay)

    # scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
                                                     patience=patience,
                                                     factor=factor,
                                                     verbose=True)
    warmup_scheduler = WarmUpLR(optimizer, len(trainloader) * warmup_epoch)

    if opt_state_dict is not None:
        optimizer.load_state_dict(opt_state_dict)

    # train
    loglist = []
    val_loss = 100
    for epoch in range(start_epoch, epochs):
        if epoch > warmup_epoch - 1:
            scheduler.step(val_loss)
        print('epoch ', epoch)
        tr_log = _trainer_v1(net, trainloader, criterion, optimizer, epoch,
                             grad_accum_steps, warmup_epoch, warmup_scheduler,
                             use_mixup=True)
        vl_log = _tester_v1(net, validloader, test_criterion)
        loglist.append(list(tr_log) + list(vl_log))
        val_loss = vl_log[0]
        save_checkpoint(epoch, net, optimizer, 'checkpoint')
        save_log(loglist, 'training_log.csv')
    return net
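# cb_beta above parameterizes the "effective number of samples" class-balanced
# weighting (Cui et al., 2019) that CB_CrossEntropyLoss presumably applies.
# A self-contained sketch of those per-class weights, for reference only:
import numpy as np

def class_balanced_weights(labels, num_class, beta=0.99):
    """w_c = (1 - beta) / (1 - beta^n_c), normalized to average to 1."""
    counts = np.bincount(labels, minlength=num_class).astype(np.float64)
    effective_num = 1.0 - np.power(beta, counts)
    weights = (1.0 - beta) / np.maximum(effective_num, 1e-12)
    return weights / weights.sum() * num_class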
def train_model(cfg: DictConfig) -> None:
    output_dir = Path.cwd()
    logging.basicConfig(format='%(asctime)s\t%(levelname)s\t%(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        filename=str(output_dir / 'log.txt'),
                        level=logging.DEBUG)

    # keep hydra from also echoing the log to the console
    logger = logging.getLogger()
    assert isinstance(logger.handlers[0], logging.StreamHandler)
    logger.handlers[0].setLevel(logging.CRITICAL)

    if cfg.gpu >= 0:
        device = torch.device(f"cuda:{cfg.gpu}")
        # noinspection PyUnresolvedReferences
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    model = load_model(model_name=cfg.model_name)
    model.to(device)

    if cfg.swa.enable:
        swa_model = AveragedModel(model=model, device=device)
    else:
        swa_model = None

    # optimizer = optim.SGD(
    #     model.parameters(), lr=cfg.optimizer.lr,
    #     momentum=cfg.optimizer.momentum,
    #     weight_decay=cfg.optimizer.weight_decay,
    #     nesterov=cfg.optimizer.nesterov
    # )
    optimizer = AdaBound(model.parameters(),
                         lr=cfg.optimizer.lr,
                         final_lr=cfg.optimizer.final_lr,
                         weight_decay=cfg.optimizer.weight_decay,
                         amsbound=False)
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)

    if cfg.scheduler.enable:
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer, T_0=1, T_mult=1,
            eta_min=cfg.scheduler.eta_min)
        # scheduler = optim.lr_scheduler.CyclicLR(
        #     optimizer, base_lr=cfg.scheduler.base_lr,
        #     max_lr=cfg.scheduler.max_lr,
        #     step_size_up=cfg.scheduler.step_size,
        #     mode=cfg.scheduler.mode
        # )
    else:
        scheduler = None

    if cfg.input_dir is not None:
        input_dir = Path(cfg.input_dir)
        model_path = input_dir / 'model.pt'
        print('load model from {}'.format(model_path))
        model.load_state_dict(torch.load(model_path))

        state_path = input_dir / 'state.pt'
        print('load optimizer state from {}'.format(state_path))
        checkpoint = torch.load(state_path, map_location=device)
        epoch = checkpoint['epoch']
        t = checkpoint['t']
        optimizer.load_state_dict(checkpoint['optimizer'])
        if cfg.swa.enable and 'swa_model' in checkpoint:
            swa_model.load_state_dict(checkpoint['swa_model'])
        if cfg.scheduler.enable and 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        if cfg.use_amp and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
    else:
        epoch = 0
        t = 0

    # hydra changes the current working directory, so make the data paths absolute
    if isinstance(cfg.train_data, str):
        train_path_list = (hydra.utils.to_absolute_path(cfg.train_data), )
    else:
        train_path_list = [
            hydra.utils.to_absolute_path(path) for path in cfg.train_data
        ]
    logging.info('train data path: {}'.format(train_path_list))
    train_data = load_train_data(path_list=train_path_list)
    train_dataset = train_data
    train_data = train_dataset[0]
    test_data = load_test_data(
        path=hydra.utils.to_absolute_path(cfg.test_data))

    logging.info('train position num = {}'.format(len(train_data)))
    logging.info('test position num = {}'.format(len(test_data)))

    train_loader = DataLoader(train_data, device=device,
                              batch_size=cfg.batch_size, shuffle=True)
    validation_loader = DataLoader(test_data[:cfg.test_batch_size * 10],
                                   device=device,
                                   batch_size=cfg.test_batch_size)
    test_loader = DataLoader(test_data, device=device,
                             batch_size=cfg.test_batch_size)

    train_writer = SummaryWriter(log_dir=str(output_dir / 'train'))
    test_writer = SummaryWriter(log_dir=str(output_dir / 'test'))

    train_metrics = Metrics()
    eval_interval = cfg.eval_interval
    total_epoch = cfg.epoch + epoch
    for e in range(cfg.epoch):
        train_metrics_epoch = Metrics()
        model.train()
        desc = 'train [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
        train_size = len(train_loader) * 4
        for x1, x2, t1, t2, z, value, mask in tqdm(train_loader, desc=desc):
            with torch.cuda.amp.autocast(enabled=cfg.use_amp):
                model.zero_grad()
                metric_value = compute_metric(model=model, x1=x1, x2=x2,
                                              t1=t1, t2=t2, z=z, value=value,
                                              mask=mask,
                                              val_lambda=cfg.val_lambda,
                                              beta=cfg.beta)
            scaler.scale(metric_value.loss).backward()
            if cfg.clip_grad_max_norm:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               cfg.clip_grad_max_norm)
            scaler.step(optimizer)
            scaler.update()

            if cfg.swa.enable and t % cfg.swa.freq == 0:
                swa_model.update_parameters(model=model)

            t += 1
            if cfg.scheduler.enable:
                scheduler.step(t / train_size)

            train_metrics.update(metric_value=metric_value)
            train_metrics_epoch.update(metric_value=metric_value)

            # print train loss
            if t % eval_interval == 0:
                model.eval()

                validation_metrics = Metrics()
                with torch.no_grad():
                    # noinspection PyAssignmentToLoopOrWithParameter
                    for x1, x2, t1, t2, z, value, mask in validation_loader:
                        m = compute_metric(model=model, x1=x1, x2=x2, t1=t1,
                                           t2=t2, z=z, value=value, mask=mask,
                                           val_lambda=cfg.val_lambda)
                        validation_metrics.update(metric_value=m)

                last_lr = (scheduler.get_last_lr()[-1]
                           if cfg.scheduler.enable else cfg.optimizer.lr)
                logging.info(
                    'epoch = {}, iteration = {}, lr = {}, {}, {}'.format(
                        epoch + 1, t, last_lr,
                        make_metric_log('train', train_metrics),
                        make_metric_log('validation', validation_metrics)))
                write_summary(writer=train_writer, metrics=train_metrics,
                              t=t, prefix='iteration')
                write_summary(writer=test_writer, metrics=validation_metrics,
                              t=t, prefix='iteration')
                train_metrics = Metrics()
                train_writer.add_scalar('learning_rate', last_lr, global_step=t)

                model.train()
            elif t % cfg.train_log_interval == 0:
                last_lr = (scheduler.get_last_lr()[-1]
                           if cfg.scheduler.enable else cfg.optimizer.lr)
                logging.info('epoch = {}, iteration = {}, lr = {}, {}'.format(
                    epoch + 1, t, last_lr,
                    make_metric_log('train', train_metrics)))
                write_summary(writer=train_writer, metrics=train_metrics,
                              t=t, prefix='iteration')
                train_metrics = Metrics()
                train_writer.add_scalar('learning_rate', last_lr, global_step=t)

        if cfg.swa.enable:
            with torch.cuda.amp.autocast(enabled=cfg.use_amp):
                desc = 'update BN [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
                np.random.shuffle(train_data)
                # recomputing the BN statistics needs a fair amount of data;
                # using all of it was more accurate than cutting it to 1/16,
                # so only limit it to an amount that runs in about 10 minutes.
                # The DataLoader may not handle non-contiguous memory correctly.
                train_data = np.ascontiguousarray(train_data[::4])
                torch.optim.swa_utils.update_bn(
                    loader=tqdm(hcpe_loader(data=train_data, device=device,
                                            batch_size=cfg.batch_size),
                                desc=desc,
                                total=len(train_data) // cfg.batch_size),
                    model=swa_model)

        # print train loss for each epoch
        test_metrics = Metrics()
        if cfg.swa.enable:
            test_model = swa_model
        else:
            test_model = model
        test_model.eval()
        with torch.no_grad():
            desc = 'test [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
            for x1, x2, t1, t2, z, value, mask in tqdm(test_loader, desc=desc):
                metric_value = compute_metric(model=test_model, x1=x1, x2=x2,
                                              t1=t1, t2=t2, z=z, value=value,
                                              mask=mask,
                                              val_lambda=cfg.val_lambda)
                test_metrics.update(metric_value=metric_value)
        logging.info('epoch = {}, iteration = {}, {}, {}'.format(
            epoch + 1, t,
            make_metric_log('train', train_metrics_epoch),
            make_metric_log('test', test_metrics)))
        write_summary(writer=train_writer, metrics=train_metrics_epoch,
                      t=epoch + 1, prefix='epoch')
        write_summary(writer=test_writer, metrics=test_metrics,
                      t=epoch + 1, prefix='epoch')

        epoch += 1
        if e != cfg.epoch - 1:
            # swap in the next chunk of training data
            train_data = train_dataset[e + 1]
            train_loader.data = train_data

    train_writer.close()
    test_writer.close()

    print('save the model')
    torch.save(model.state_dict(), output_dir / 'model.pt')

    print('save the optimizer')
    state = {'epoch': epoch, 't': t, 'optimizer': optimizer.state_dict()}
    if cfg.scheduler.enable:
        state['scheduler'] = scheduler.state_dict()
    if cfg.swa.enable:
        state['swa_model'] = swa_model.state_dict()
    if cfg.use_amp:
        state['scaler'] = scaler.state_dict()
    torch.save(state, output_dir / 'state.pt')
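# A minimal hydra entry point for train_model (sketch; the config path and
# name below are assumptions, the real decorator arguments may differ):
import hydra
from omegaconf import DictConfig

@hydra.main(config_path='.', config_name='config')
def main(cfg: DictConfig) -> None:
    train_model(cfg)

if __name__ == '__main__':
    main()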
def train(model, test_loader, lang, args, pairs, extra_loader):
    start = time.time()

    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001,
                                     betas=(0.9, 0.98), eps=1e-9)
    elif args.optimizer == "adabound":
        optimizer = AdaBound(model.parameters(), lr=0.0001, final_lr=0.1)
    else:
        print("unknown optimizer.")
        exit(0)

    print_model(args)
    save_name = get_name(args)
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    print("the model is saved in: " + save_name)

    n_epochs = args.n_epochs
    step = 0.0
    begin_epoch = 0
    best_val_bleu = 0

    if not args.from_scratch:
        if os.path.exists(save_name):
            checkpoint = torch.load(save_name)
            model.load_state_dict(checkpoint['model_state_dict'])
            lr = checkpoint['lr']
            step = checkpoint['step']
            begin_epoch = checkpoint['epoch'] + 1
            best_val_bleu = checkpoint['bleu']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print("load successful!")
            checkpoint = []  # free the loaded checkpoint
        else:
            print("load unsuccessful!")

    if args.use_dataset_B:
        extra_iter = iter(extra_loader)
        num_iter = 1
        if args.ratio >= 2:
            num_iter = int(args.ratio)
    else:
        extra_iter = None
        num_iter = 0

    for epoch in range(begin_epoch, n_epochs):
        model.train()
        train_loader = Data.DataLoader(pairs, batch_size=args.batch_size,
                                       shuffle=True)
        print_loss_total, step, extra_iter, lr = train_epoch(
            model, lang, args, train_loader, extra_loader, extra_iter,
            num_iter, optimizer, step)
        print('total loss: %f' % print_loss_total)

        model.eval()
        curr_bleu = evaluate(model, test_loader, lang, args.max_length)
        print('%s (epoch: %d %d%%)' % (
            timeSince(start, (epoch + 1 - begin_epoch) / (n_epochs - begin_epoch)),
            epoch,
            (epoch + 1 - begin_epoch) / (n_epochs - begin_epoch) * 100))

        if curr_bleu > best_val_bleu:
            best_val_bleu = curr_bleu
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'lr': lr,
                'step': step,
                'epoch': epoch,
                'bleu': curr_bleu,
            }, save_name)
            print("checkpoint saved!")
        print()
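# Hypothetical call site for train() (names illustrative; the surrounding
# script builds the model, vocabulary, and loaders first):
#
#     model = build_model(lang.n_words, args)  # assumption: repo helper
#     train(model, test_loader, lang, args, train_pairs, extra_loader)
#
# Checkpoints are written to get_name(args) only when validation BLEU improves,
# so interrupted runs resume from the best-scoring epoch.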