def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model, optimizer = get_model(args)
    model.train()
    train_loader, _ = data_loader(args)

    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)
        fw.write('#epoch,loss,pred@1,pred@5\n')

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)

    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        top1.reset()
        top5.reset()
        batch_time.reset()

        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        if res:
            # re-open the record file to log the new learning rate
            with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
                for g in optimizer.param_groups:
                    out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, label = img.cuda(), label.cuda()
            img_var, label_var = Variable(img), Variable(label)

            logits = model(img_var, label_var)
            # get_loss returns a one-element tuple; unpack the single loss term
            loss_val, = model.module.get_loss(logits, label_var)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            if not args.onehot == 'True':
                logits1 = torch.squeeze(logits[0])
                prec1_1, prec5_1 = Metrics.accuracy(logits1.data, label.long(), topk=(1, 5))
                top1.update(prec1_1[0], img.size()[0])
                top5.update(prec5_1[0], img.size()[0])

            # .data[0] is invalid on 0-dim tensors in current PyTorch; use .item()
            losses.update(loss_val.data.item(), img.size()[0])
            batch_time.update(time.time() - end)
            end = time.time()

            if global_counter % 1000 == 0:
                losses.reset()
                top1.reset()
                top5.reset()

            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                eta_seconds = ((total_epoch - current_epoch) * steps_per_epoch +
                               (steps_per_epoch - idx)) * batch_time.avg
                eta_str = "{:0>8}".format(datetime.timedelta(seconds=int(eta_seconds)))
                eta_seconds_epoch = steps_per_epoch * batch_time.avg
                eta_str_epoch = "{:0>8}".format(datetime.timedelta(seconds=int(eta_seconds_epoch)))
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'ETA {eta_str}({eta_str_epoch})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          current_epoch, global_counter % len(train_loader), len(train_loader),
                          batch_time=batch_time, eta_str=eta_str, eta_str_epoch=eta_str_epoch,
                          loss=losses, top1=top1, top5=top5))

        if current_epoch % 1 == 0:
            save_checkpoint(args,
                            {'epoch': current_epoch,
                             'arch': 'resnet',
                             'global_counter': global_counter,
                             'state_dict': model.state_dict(),
                             'optimizer': optimizer.state_dict()},
                            is_best=False,
                            filename='%s_epoch_%d_glo_step_%d.pth.tar' %
                                     (args.dataset, current_epoch, global_counter))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
            fw.write('%d,%.4f,%.3f,%.3f\n' % (current_epoch, losses.avg, top1.avg, top5.avg))

        current_epoch += 1
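# The loops in this file rely on a few shared helpers that are not defined here.
# Below is a minimal sketch of `AverageMeter` and `save_checkpoint` as they are
# commonly written in PyTorch training code; the repo's own versions may differ,
# and the use of `args.snapshot_dir` as the checkpoint directory is an assumption.
import os
import torch


class AverageMeter(object):
    """Tracks the most recent value, running sum, count, and average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def save_checkpoint(args, state, is_best, filename='checkpoint.pth.tar'):
    # Sketch: serialize the state dictionary into args.snapshot_dir.
    # `is_best` handling (copying to a best-model file) is omitted here.
    save_path = os.path.join(args.snapshot_dir, filename)
    torch.save(state, save_path)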
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch

    train_loader = train_data_loader_iam(args)
    max_step = total_epoch * len(train_loader)
    args.max_step = max_step
    print('Max step:', max_step)

    model, optimizer, criterion = get_model(args)
    print(model)
    model.train()
    end = time.time()

    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        batch_time.reset()

        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        steps_per_epoch = len(train_loader)

        for idx, dat in enumerate(train_loader):
            img, label = dat
            label = label.cuda(non_blocking=True)

            logits = model(img)
            if len(logits.shape) == 1:
                logits = logits.reshape(label.shape)
            loss_val = criterion(logits, label)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            losses.update(loss_val.data.item(), img.size()[0])
            batch_time.update(time.time() - end)
            end = time.time()
            global_counter += 1

            if global_counter % 1000 == 0:
                losses.reset()

            if global_counter % args.disp_interval == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'LR: {:.5f}\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          current_epoch, global_counter % len(train_loader), len(train_loader),
                          optimizer.param_groups[0]['lr'], loss=losses))

        if current_epoch == args.epoch - 1:
            save_checkpoint(args,
                            {'epoch': current_epoch,
                             'global_counter': global_counter,
                             'state_dict': model.state_dict(),
                             'optimizer': optimizer.state_dict()},
                            is_best=False,
                            filename='%s_epoch_%d.pth' % (args.dataset, current_epoch))

        current_epoch += 1
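# Every loop here calls `my_optim.reduce_lr(args, optimizer, current_epoch)` and several
# of them log the new learning rate when it returns a truthy value. A minimal sketch of
# a step-decay implementation is given below; the argument name `args.decay_points`
# (comma-separated epochs) and the 10x decay factor are assumptions, not the repo's
# actual settings.
def reduce_lr(args, optimizer, epoch, factor=0.1):
    # Decay the learning rate of every parameter group at the listed epochs.
    decay_points = [int(p) for p in str(args.decay_points).split(',') if p.strip()]
    if epoch not in decay_points:
        return False
    for g in optimizer.param_groups:
        g['lr'] = g['lr'] * factor
    return True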
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch

    train_loader, val_loader = train_data_loader(args)
    max_step = total_epoch * len(train_loader)
    args.max_step = max_step
    print('Max step:', max_step)

    model, optimizer = get_model(args)
    model.train()
    print(model)
    end = time.time()

    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        batch_time.reset()

        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        steps_per_epoch = len(train_loader)
        index = 0

        for idx, dat in enumerate(train_loader):
            img_name1, img1, label1, img_name2, img2, label2, img_name3, img3, label3 = dat
            label1 = label1.cuda(non_blocking=True)
            label2 = label2.cuda(non_blocking=True)
            label3 = label3.cuda(non_blocking=True)

            x11, x1, x22, x2, x33, x3 = model(img1, img2, img3, current_epoch, label1, index)
            index += 1

            loss_train = 0.4 * (F.multilabel_soft_margin_loss(x11, label1) +
                                F.multilabel_soft_margin_loss(x22, label2) +
                                F.multilabel_soft_margin_loss(x33, label3)) + \
                         (F.multilabel_soft_margin_loss(x1, label1) +
                          F.multilabel_soft_margin_loss(x2, label2) +
                          F.multilabel_soft_margin_loss(x3, label3))

            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()

            # `img` is undefined in this loop; use the first image of the triplet for the batch size
            losses.update(loss_train.data.item(), img1.size()[0])
            batch_time.update(time.time() - end)
            end = time.time()
            global_counter += 1

            if global_counter % 1000 == 0:
                losses.reset()

            if global_counter % args.disp_interval == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'LR: {:.5f}\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          current_epoch, global_counter % len(train_loader), len(train_loader),
                          optimizer.param_groups[0]['lr'], loss=losses))

        if current_epoch == args.epoch - 1:
            save_checkpoint(args,
                            {'epoch': current_epoch,
                             'global_counter': global_counter,
                             'state_dict': model.state_dict(),
                             'optimizer': optimizer.state_dict()},
                            is_best=False,
                            filename='%s_epoch_%d.pth' % (args.dataset, current_epoch))

        current_epoch += 1
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    loss_cls = AverageMeter()
    loss_dist = AverageMeter()
    loss_aux = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model, optimizer = get_model(args)
    model.train()
    train_loader, _ = data_loader(args)

    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)
        fw.write('#epoch,loss,pred@1,pred@5\n')

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)

    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        loss_cls.reset()
        loss_dist.reset()
        loss_aux.reset()
        top1.reset()
        top5.reset()
        batch_time.reset()

        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        if res:
            for g in optimizer.param_groups:
                out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, label = img.cuda(), label.cuda()
            img_var, label_var = Variable(img), Variable(label)

            logits = model(img_var, label_var)
            loss_list = model.module.get_loss(logits, label_var)
            loss_val = loss_list[0]

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            losses.update(loss_val.data.item(), img.size(0))
            loss_cls.update(loss_list[1].data.item(), img.size(0))
            loss_dist.update(loss_list[2].data.item(), img.size(0))
            loss_aux.update(loss_list[3].data.item(), img.size(0))
            batch_time.update(time.time() - end)
            end = time.time()

            if global_counter % 1000 == 0:
                losses.reset()
                top1.reset()
                top5.reset()
                loss_cls.reset()
                loss_dist.reset()
                loss_aux.reset()
                batch_time.reset()

            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                # eta_seconds = ((total_epoch - current_epoch)*steps_per_epoch + (steps_per_epoch - idx))*batch_time.avg
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Loss_cls {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
                      'Loss_dist {loss_d.val:.4f} ({loss_d.avg:.4f})\t'
                      'Loss_aux {loss_aux.val:.4f} ({loss_aux.avg:.4f})\t'.format(
                          current_epoch, global_counter % len(train_loader), len(train_loader),
                          batch_time=batch_time, loss=losses, loss_c=loss_cls,
                          loss_d=loss_dist, loss_aux=loss_aux))

        if current_epoch % args.save_interval == 0:
            model_stat_dict = model.module.state_dict()
            save_checkpoint(args,
                            {'epoch': current_epoch,
                             'arch': 'resnet',
                             'global_counter': global_counter,
                             'state_dict': model_stat_dict,
                             'optimizer': optimizer.state_dict(),
                             'center_feat_bank': model.module.center_feat_bank},
                            is_best=False,
                            filename='%s_epoch_%d_glo_step_%d.pth' %
                                     (args.dataset, current_epoch, global_counter))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
            fw.write('%d,%.4f,%.3f,%.3f\n' % (current_epoch, losses.avg, top1.avg, top5.avg))

        current_epoch += 1
def train(args):
    batch_time = AverageMeter()
    lossCos = AverageMeter()
    losses = AverageMeter()
    loss_root = AverageMeter()
    loss_parent = AverageMeter()
    loss_child = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_parent = AverageMeter()
    top5_parent = AverageMeter()
    top1_root = AverageMeter()
    top5_root = AverageMeter()

    model, optimizer = get_model(args)
    model.train()
    train_loader, _, _ = data_loader(args)

    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)
        fw.write('#epoch \t loss \t pred@1 \t pred@5\n')

    # construct writer
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(log_dir=args.log_dir)

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)

    while current_epoch < total_epoch:
        model.train()
        lossCos.reset()
        losses.reset()
        loss_root.reset()
        loss_parent.reset()
        loss_child.reset()
        top1.reset()
        top5.reset()
        top1_parent.reset()
        top5_parent.reset()
        top1_root.reset()
        top5_root.reset()
        batch_time.reset()

        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        if res:
            with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
                for g in optimizer.param_groups:
                    out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, root_label, parent_label, child_label = \
                img.cuda(), label[0].cuda(), label[1].cuda(), label[2].cuda()
            img_var, root_label_var, parent_label_var, child_label_var = \
                Variable(img), Variable(root_label), Variable(parent_label), Variable(child_label)

            logits = model(img_var)
            loss_val, loss_root_val, loss_parent_val, loss_child_val, lossCos_val = \
                model.module.get_loss(logits, root_label_var, parent_label_var, child_label_var)

            # write into tensorboard
            writer.add_scalar('loss_val', loss_val, global_counter)

            # network parameter update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            if not args.onehot == 'True':
                logits5 = torch.squeeze(logits[-1])
                prec1, prec5 = evaluate.accuracy(logits5.data, child_label.long(), topk=(1, 5))
                top1.update(prec1[0], img.size()[0])
                top5.update(prec5[0], img.size()[0])

                logits4 = torch.squeeze(logits[-2])
                prec1_4, prec5_4 = evaluate.accuracy(logits4.data, parent_label.long(), topk=(1, 5))
                top1_parent.update(prec1_4[0], img.size()[0])
                top5_parent.update(prec5_4[0], img.size()[0])

                logits3 = torch.squeeze(logits[-3])
                prec1_3, prec5_3 = evaluate.accuracy(logits3.data, root_label.long(), topk=(1, 5))
                top1_root.update(prec1_3[0], img.size()[0])
                top5_root.update(prec5_3[0], img.size()[0])

            losses.update(loss_val.data, img.size()[0])
            loss_root.update(loss_root_val.data, img.size()[0])
            loss_parent.update(loss_parent_val.data, img.size()[0])
            loss_child.update(loss_child_val.data, img.size()[0])
            lossCos.update(lossCos_val.data, img.size()[0])
            batch_time.update(time.time() - end)
            end = time.time()

            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                eta_seconds = ((total_epoch - current_epoch) * steps_per_epoch +
                               (steps_per_epoch - idx)) * batch_time.avg
                eta_str = "{:0>8}".format(datetime.timedelta(seconds=int(eta_seconds)))
                eta_seconds_epoch = steps_per_epoch * batch_time.avg
                eta_str_epoch = "{:0>8}".format(datetime.timedelta(seconds=int(eta_seconds_epoch)))
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'ETA {eta_str}({eta_str_epoch})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Loss cos {lossCos.val:.4f} ({lossCos.avg:.4f})\t'
                      'Loss root {loss_root.val:.4f} ({loss_root.avg:.4f})\t'
                      'Loss parent {loss_parent.val:.4f} ({loss_parent.avg:.4f})\t'
                      'Loss child {loss_child.val:.4f} ({loss_child.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                      'parent: Prec@1 {top1_parent.val:.3f} ({top1_parent.avg:.3f})\t'
                      'Prec@5 {top5_parent.val:.3f} ({top5_parent.avg:.3f})\t'
                      'root: Prec@1 {top1_root.val:.3f} ({top1_root.avg:.3f})\t'
                      'Prec@5 {top5_root.val:.3f} ({top5_root.avg:.3f})'.format(
                          current_epoch, global_counter % len(train_loader), len(train_loader),
                          batch_time=batch_time, eta_str=eta_str, eta_str_epoch=eta_str_epoch,
                          loss=losses, loss_root=loss_root, loss_parent=loss_parent,
                          loss_child=loss_child, top1=top1, top5=top5,
                          top1_parent=top1_parent, top5_parent=top5_parent,
                          top1_root=top1_root, top5_root=top5_root, lossCos=lossCos))

        # epoch-level curves
        plotter.plot('rootLoss', 'train', current_epoch, loss_root.avg)
        plotter.plot('childLoss', 'train', current_epoch, loss_child.avg)
        plotter.plot('parentLoss', 'train', current_epoch, loss_parent.avg)
        plotter.plot('cosLoss', 'train', current_epoch, lossCos.avg)
        plotter.plot('top1', 'train', current_epoch, top1.avg)
        plotter.plot('top5', 'train', current_epoch, top5.avg)
        plotter.plot('parent Top1', 'train', current_epoch, top1_parent.avg)
        plotter.plot('parent Top5', 'train', current_epoch, top5_parent.avg)
        plotter.plot('root Top1', 'train', current_epoch, top1_root.avg)
        plotter.plot('root Top5', 'train', current_epoch, top5_root.avg)

        current_epoch += 1
        if current_epoch % 50 == 0:
            save_checkpoint(args,
                            {'epoch': current_epoch,
                             'arch': 'resnet',
                             'global_counter': global_counter,
                             'state_dict': model.state_dict(),
                             'optimizer': optimizer.state_dict()},
                            is_best=False,
                            filename='%s_epoch_%d_glo_step_%d.pth.tar' %
                                     (args.dataset, current_epoch, global_counter))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
            fw.write('%d \t %.4f \t %.4f \t %.4f \t %.4f \t %.4f \t %.3f \t %.3f\t %.3f \t %.3f\t %.3f \t %.3f\n' %
                     (current_epoch, losses.avg, lossCos.avg, loss_root.avg, loss_parent.avg,
                      loss_child.avg, top1_root.avg, top5_root.avg, top1_parent.avg,
                      top5_parent.avg, top1.avg, top5.avg))

        losses.reset()
        loss_root.reset()
        loss_parent.reset()
        loss_child.reset()
        top1.reset()
        top5.reset()
        top1_parent.reset()
        top5_parent.reset()
        top1_root.reset()
        top5_root.reset()
        lossCos.reset()
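# `evaluate.accuracy` above (and `Metrics.accuracy` in the first loop) computes top-k
# precision from raw logits and integer class labels. A sketch following the standard
# PyTorch ImageNet recipe is given here; the repo's own helper may differ in detail.
import torch


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)  # indices of the top-k scores per sample
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res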
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()
    losses2_1 = AverageMeter()
    losses2_2 = AverageMeter()
    losses3_1 = AverageMeter()
    losses3_2 = AverageMeter()
    losses4_1 = AverageMeter()
    losses4_2 = AverageMeter()
    losses1_comple = AverageMeter()
    losses2_comple = AverageMeter()

    total_epoch = args.epoch
    train_loader, val_loader = train_data_loader_siamese_more_augumentation(args)
    # train_loader, val_loader = train_data_loader_normal_resize(args)
    max_step = total_epoch * len(train_loader)
    args.max_step = max_step
    print('Max step:', max_step)

    model, optimizer = get_model(args)
    print(model)

    global_counter = args.global_counter
    print("here: ", global_counter)
    current_epoch = args.current_epoch
    model.train()
    end = time.time()

    while current_epoch < total_epoch:
        losses.reset()
        losses1.reset()
        losses2.reset()
        losses2_1.reset()
        losses2_2.reset()
        losses3_1.reset()
        losses3_2.reset()
        losses4_1.reset()
        losses4_2.reset()
        losses1_comple.reset()
        losses2_comple.reset()
        batch_time.reset()

        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        steps_per_epoch = len(train_loader)

        validate(model, val_loader)
        model.train()

        # prepare for training
        index = 0
        for idx, dat in enumerate(train_loader):
            _, _, input1, input2, input1_transforms, label1, label2 = dat

            # hide-patch augmentation is effectively disabled here (probability 0.0)
            if random.random() < 0.0:
                input1 = hide_patch(input1)
                input2 = hide_patch(input2)
                input1_transforms = [hide_patch(i) for i in input1_transforms]

            img = [input1, input2]
            label = torch.cat([label1, label2])
            img2 = [input1, input1_transforms[0]]
            img3 = [input1, input1_transforms[1]]
            img4 = [input1, input1_transforms[2]]

            # label = label.cuda(non_blocking=True)
            # label1 = label1.cuda(non_blocking=True)
            # label2 = label2.cuda(non_blocking=True)
            label = label.cuda()
            label1 = label1.cuda()
            label2 = label2.cuda()

            # classes present in both images vs. the complements present in only one
            label_new = label1 + label2
            label_new[label_new != 2] = 0
            label_new[label_new == 2] = 1
            label1_comple = label1 - label_new
            label2_comple = label2 - label_new
            assert (label1_comple >= 0).all() and (label2_comple >= 0).all()
            label_new = torch.cat([label_new, label_new])

            logits, co_logits = model(img, current_epoch, label, None)
            logits2, co_logits2 = model(img2, current_epoch, label, None)
            logits3, co_logits3 = model(img3, current_epoch, label, None)
            logits4, co_logits4 = model(img4, current_epoch, label, None)
            index += args.batch_size

            if logits is None:
                print("here")
                continue
            if len(logits.shape) == 1:
                logits = logits.reshape(label.shape)

            # loss_val1 = F.multilabel_soft_margin_loss(logits[:input1.size(0)], label[:input1.size(0)])
            loss_val1 = F.multilabel_soft_margin_loss(logits, label)
            loss_val2 = F.multilabel_soft_margin_loss(co_logits[:2 * input1.size(0)], label_new)
            loss_val1_comple = F.multilabel_soft_margin_loss(
                co_logits[2 * input1.size(0):3 * input1.size(0)], label1_comple)
            loss_val2_comple = F.multilabel_soft_margin_loss(
                co_logits[3 * input1.size(0):], label2_comple)
            loss_val2_1 = F.multilabel_soft_margin_loss(logits2, torch.cat([label1, label1]))
            loss_val2_2 = F.multilabel_soft_margin_loss(
                co_logits2[:2 * input1.size(0)], torch.cat([label1, label1]))
            loss_val3_1 = F.multilabel_soft_margin_loss(logits3, torch.cat([label1, label1]))
            loss_val3_2 = F.multilabel_soft_margin_loss(
                co_logits3[:2 * input1.size(0)], torch.cat([label1, label1]))
            loss_val4_1 = F.multilabel_soft_margin_loss(logits4, torch.cat([label1, label1]))
            loss_val4_2 = F.multilabel_soft_margin_loss(
                co_logits4[:2 * input1.size(0)], torch.cat([label1, label1]))

            # use co-attention
            loss_val = (loss_val1 + loss_val2 + loss_val2_1 + loss_val2_2 +
                        loss_val3_1 + loss_val3_2 + loss_val4_1 + loss_val4_2)
            # don't use co-attention
            # loss_val = loss_val1 + loss_val2_1 + loss_val3_1 + loss_val4_1
            # loss_val = loss_val4_1 + loss_val4_2

            if current_epoch >= 2:
                if (label1_comple > 0).any():
                    loss_val = loss_val + loss_val1_comple
                if (label2_comple > 0).any():
                    loss_val = loss_val + loss_val2_comple

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            losses.update(loss_val.data.item(), input1.size()[0] + input2.size()[0])
            losses1.update(loss_val1.data.item(), input1.size()[0] + input2.size()[0])
            losses2.update(loss_val2.data.item(), input1.size()[0] + input2.size()[0])
            losses2_1.update(loss_val2_1.data.item(), input1.size()[0] + input2.size()[0])
            losses2_2.update(loss_val2_2.data.item(), input1.size()[0] + input2.size()[0])
            losses3_1.update(loss_val3_1.data.item(), input1.size()[0] + input2.size()[0])
            losses3_2.update(loss_val3_2.data.item(), input1.size()[0] + input2.size()[0])
            losses4_1.update(loss_val4_1.data.item(), input1.size()[0] + input2.size()[0])
            losses4_2.update(loss_val4_2.data.item(), input1.size()[0] + input2.size()[0])
            if (label1_comple > 0).any():
                losses1_comple.update(loss_val1_comple.data.item(),
                                      input1.size()[0] + input2.size()[0])
            if (label2_comple > 0).any():
                losses2_comple.update(loss_val2_comple.data.item(),
                                      input1.size()[0] + input2.size()[0])

            batch_time.update(time.time() - end)
            end = time.time()
            global_counter += 1

            if global_counter % 1000 == 0:
                losses.reset()
                losses1.reset()
                losses2.reset()
                losses2_1.reset()
                losses2_2.reset()
                losses3_1.reset()
                losses3_2.reset()
                losses4_1.reset()
                losses4_2.reset()
                losses1_comple.reset()
                losses2_comple.reset()

            if global_counter % args.disp_interval == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'LR: {:.5f}\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          current_epoch, global_counter % len(train_loader), len(train_loader),
                          optimizer.param_groups[0]['lr'], loss=losses))
                print(losses.avg, losses1.avg, losses2.avg, losses2_1.avg, losses2_2.avg,
                      losses3_1.avg, losses3_2.avg, losses4_1.avg, losses4_2.avg,
                      losses1_comple.avg, losses2_comple.avg)

        # if current_epoch == args.epoch-1:
        save_checkpoint(args,
                        {'epoch': current_epoch,
                         'global_counter': global_counter,
                         'state_dict': model.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        is_best=False,
                        filename='%s_epoch_%d.pth' % (args.dataset, current_epoch))

        current_epoch += 1
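# `hide_patch` above (effectively disabled, since it is gated by `random.random() < 0.0`)
# implements Hide-and-Seek style augmentation: the image is divided into a grid and each
# cell is zeroed out with some probability. A sketch under assumed defaults (4x4 grid,
# hide probability 0.5) is below; the repo's grid size and fill value may differ.
import torch


def hide_patch(img, grid_size=4, hide_prob=0.5):
    # img: (B, C, H, W) tensor; cells are dropped independently for each image in the batch.
    img = img.clone()
    b, c, h, w = img.shape
    gh, gw = h // grid_size, w // grid_size
    for bi in range(b):
        for gy in range(grid_size):
            for gx in range(grid_size):
                if torch.rand(1).item() < hide_prob:
                    img[bi, :, gy * gh:(gy + 1) * gh, gx * gw:(gx + 1) * gw] = 0
    return img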
def train(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str

    # for reproducibility
    if args.seed is not None:
        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = False
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    else:
        cudnn.benchmark = True

    print('Running parameters:\n')
    print(json.dumps(vars(args), indent=4, separators=(',', ':')))

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)

    log_head = '#epoch \t loss \t pred@1 \t pred@5'
    batch_time = AverageMeter()
    losses = AverageMeter()
    if args.ram:
        losses_ra = AverageMeter()
        log_head += 'loss_ra \t '
    log_head += '\n'
    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        fw.write(log_head)

    top1 = AverageMeter()
    top5 = AverageMeter()

    args.device = torch.device('cuda') if args.gpus[0] >= 0 else torch.device('cpu')
    model, optimizer = get_model(args)
    model.train()
    train_loader = data_loader(args)

    # construct writer
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(log_dir=args.log_dir)

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)

    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        if args.ram:
            losses_ra.reset()
        top1.reset()
        top5.reset()
        batch_time.reset()

        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        if res:
            with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
                for g in optimizer.param_groups:
                    out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, label = img.to(args.device), label.to(args.device)

            logits, _, _ = model(img)
            loss_val, loss_ra = model.module.get_loss(logits, label,
                                                      epoch=current_epoch,
                                                      ram_start=args.ram_start)

            # write into tensorboard
            writer.add_scalar('loss_val', loss_val, global_counter)

            # network parameter update
            optimizer.zero_grad()
            # if args.mixp:
            #     with amp.scale_loss(loss_val, optimizer) as scaled_loss:
            #         scaled_loss.backward()
            # else:
            loss_val.backward()
            optimizer.step()

            logits = torch.mean(torch.mean(logits, dim=2), dim=2)
            if not args.onehot == 'True':
                prec1, prec5 = evaluate.accuracy(logits.data, label.long(), topk=(1, 5))
                top1.update(prec1[0], img.size()[0])
                top5.update(prec5[0], img.size()[0])

            losses.update(loss_val.data, img.size()[0])
            if args.ram:
                losses_ra.update(loss_ra.data, img.size()[0])
            batch_time.update(time.time() - end)
            end = time.time()

            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                eta_seconds = ((total_epoch - current_epoch) * steps_per_epoch +
                               (steps_per_epoch - idx)) * batch_time.avg
                eta_str = "{:0>8}".format(str(datetime.timedelta(seconds=int(eta_seconds))))
                eta_seconds_epoch = steps_per_epoch * batch_time.avg
                eta_str_epoch = "{:0>8}".format(str(datetime.timedelta(seconds=int(eta_seconds_epoch))))
                log_output = ('Epoch: [{0}][{1}/{2}] \t '
                              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t '
                              'ETA {eta_str}({eta_str_epoch})\t '
                              'Loss {loss.val:.4f} ({loss.avg:.4f})\t '
                              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t '
                              'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                                  current_epoch, global_counter % len(train_loader), len(train_loader),
                                  batch_time=batch_time, eta_str=eta_str, eta_str_epoch=eta_str_epoch,
                                  loss=losses, top1=top1, top5=top5))
                if args.ram:
                    log_output += 'Loss_ra {loss_ra.val:.4f} ({loss_ra.avg:.4f})\t'.format(
                        loss_ra=losses_ra)
                print(log_output)
                writer.add_scalar('top1', top1.avg, global_counter)
                writer.add_scalar('top5', top5.avg, global_counter)

        current_epoch += 1
        if current_epoch % 10 == 0:
            save_checkpoint(args,
                            {'epoch': current_epoch,
                             'arch': args.arch,
                             'global_counter': global_counter,
                             'state_dict': model.state_dict(),
                             'optimizer': optimizer.state_dict()},
                            is_best=False,
                            filename='%s_epoch_%d.pth.tar' % (args.dataset, current_epoch))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
            log_output = '{} \t {:.4f} \t {:.3f} \t {:.3f} \t'.format(
                current_epoch, losses.avg, top1.avg, top5.avg)
            if args.ram:
                log_output += '{:.4f}'.format(losses_ra.avg)
            log_output += '\n'
            fw.write(log_output)

        losses.reset()
        if args.ram:
            losses_ra.reset()
        top1.reset()
        top5.reset()
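# All of these loops resume from `args.global_counter` / `args.current_epoch`, and the
# checkpoints saved above store both counters alongside the model and optimizer state.
# A minimal sketch of the matching resume step is shown here; the field names follow the
# dictionaries passed to save_checkpoint above, while the `args.restore_from` argument
# name is an assumption.
import torch


def restore_checkpoint(args, model, optimizer):
    # Load a checkpoint produced by save_checkpoint and put its counters back into args.
    checkpoint = torch.load(args.restore_from, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    args.current_epoch = checkpoint['epoch']
    args.global_counter = checkpoint['global_counter']
    return model, optimizer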